
Conversation

@ch3nku1

@ch3nku1 ch3nku1 commented Dec 15, 2025

I am a developer from Cybercore. We are developing a multimodal model named Leum and plan to deploy it using the vLLM transformers modeling backend. We noticed that the current implementation does not support video input, which is a necessary feature for our model. This pull request introduces the required changes to enable video input processing.

Key changes:

  • Extended multimodal classes (MultiModalProcessingInfo, MultiModalProcessor, MultiModalMixin) to handle video-specific logic, including token calculation, dummy data generation, and embedding.
  • Corrected the frame size extraction for video frames in vllm/multimodal/parse.py.
  • Updated documentation to reflect video support.
  • Fixed a potential OOM issue in the dummy batch generator for multimodal models.

Thank you for considering our contribution!

…ckend

Key changes:
- Extended multimodal classes (`MultiModalProcessingInfo`, `MultiModalProcessor`, `MultiModalMixin`) to handle video-specific logic, including token calculation, dummy data generation, and embedding.
- Corrected the frame size extraction for video frames in `vllm/multimodal/parse.py`.
- Updated documentation to reflect video support.
- Fixed a potential OOM issue in the dummy batch generator for multimodal models.

Signed-off-by: chenkui.shen <[email protected]>
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@mergify

mergify bot commented Dec 15, 2025

Documentation preview: https://vllm--30680.org.readthedocs.build/en/30680/

@mergify mergify bot added documentation Improvements or additions to documentation multi-modality Related to multi-modality (#4194) v1 labels Dec 15, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces video input support for the transformers modeling backend. The changes are comprehensive, touching multimodal processing classes, documentation, and fixing a bug in video frame size extraction. A key contribution is a fix for an OOM issue during dummy batch generation for profiling. While the fix prevents crashes, I've identified a critical issue with the approach and suggested a more robust solution to ensure accurate memory profiling and prevent potential OOMs in production.

Comment on lines 1120 to +1132
dummy_modality = mm_budget.get_modality_with_max_tokens()
return self._get_mm_dummy_batch(dummy_modality, num_seqs)

# TBD:
# The mm_dummy_batch below is only retrieved when
# supports_multimodal_raw_input_only is True.
# Currently, only the transform modeling backend and terratorch have
# supports_multimodal_raw_input_only as True.
# When testing the transform modeling backend, it was found that
# if num_seqs (usually the default 256) is passed in here,
# an OOM error occurs.
# It needs to be confirmed what value should be passed in here,
# for now it is fixed to 1.
return self._get_mm_dummy_batch(dummy_modality, 1)
Contributor

critical

Hardcoding the dummy batch size to 1 for profiling is a pragmatic fix to avoid OOM errors, but it can lead to inaccurate memory profiling. This might cause the scheduler to underestimate memory usage, potentially leading to OOM errors in production with larger batches.

A more robust approach would be to calculate a reasonable batch size based on the model's configuration. This provides a more realistic batch size for profiling, reducing the risk of production OOMs while still preventing OOMs during profiling.

        dummy_modality = mm_budget.get_modality_with_max_tokens()
        max_tokens_per_item = mm_budget.max_tokens_by_modality.get(dummy_modality)

        if max_tokens_per_item and max_tokens_per_item > 0:
            # Heuristic to derive a reasonable batch size for profiling.
            # Using max_num_seqs can cause OOM for vision models.
            # Hardcoding to 1 can lead to inaccurate profiling.
            num_items = self.scheduler_config.max_num_batched_tokens // max_tokens_per_item
            # Also respect the per-prompt limit and max sequences.
            max_items_for_modality = mm_budget.max_items_per_batch_by_modality[dummy_modality]
            num_items = min(num_items, max_items_for_modality)
            # Ensure at least 1 item.
            num_items = max(num_items, 1)
        else:
            # Fallback for safety, though this path should ideally not be taken.
            num_items = 1

        return self._get_mm_dummy_batch(dummy_modality, num_items)

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small but essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@hmellor hmellor self-assigned this Dec 15, 2025
Member

@hmellor hmellor left a comment

Awesome PR! Thank you for using and contributing to the Transformers modelling backend!

I've left a few comments and looped in @zucchini-nlp. I want to make sure that the standards we're defining here for video model interfaces align with how Transformers would like to standardise all the video models in the Transformers library.

Comment on lines +1122 to +1132
# TBD:
# The mm_dummy_batch below is only retrieved when
# supports_multimodal_raw_input_only is True.
# Currently, only the transform modeling backend and terratorch have
# supports_multimodal_raw_input_only as True.
# When testing the transform modeling backend, it was found that
# if num_seqs (usually the default 256) is passed in here,
# an OOM error occurs.
# It needs to be confirmed what value should be passed in here,
# for now it is fixed to 1.
return self._get_mm_dummy_batch(dummy_modality, 1)
Member

I was recently working in this area and noticed this too. It causes 256 x 100MP images (the max image size defined in vllm/model_executor/models/transformers/multimodal.py) to be materialised on the GPU.

@ywang96 what do you think should be the correct behaviour here?

Member

@DarkLight1337 DarkLight1337 Dec 15, 2025

I think the profiling here is actually correct for models that have is_multimodal_raw_input_only_model, since the maximum number of videos possible during inference is indeed based on max_num_seqs (and can actually be more than that if the model accepts more than 1 video per prompt).

Member

@DarkLight1337 DarkLight1337 Dec 15, 2025

So I prefer patching this on the model side. If removing supports_multimodal_raw_input_only=True is not feasible, then maybe you can override the mm_counts in get_dummy_mm_data to use the same value as for regular MM models.
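For illustration, a minimal sketch of the kind of capping being suggested (a hypothetical helper, not part of this PR; in practice it would live inside the model's get_dummy_mm_data override):

from collections.abc import Mapping


def cap_profiling_mm_counts(
    mm_counts: Mapping[str, int],
    per_prompt_limits: Mapping[str, int],
) -> dict[str, int]:
    # Clamp the dummy-data item counts to the regular per-prompt limits so the
    # raw-input-only path doesn't profile with max_num_seqs items per modality.
    return {
        modality: min(count, per_prompt_limits.get(modality, count))
        for modality, count in mm_counts.items()
    }


# Example: 256 sequences' worth of videos collapses back to the per-prompt limit.
print(cap_profiling_mm_counts({"image": 256, "video": 256}, {"image": 1, "video": 1}))
# -> {'image': 1, 'video': 1}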

Member

I forget the exact reason we chose to use supports_multimodal_raw_input_only=True. IIRC it was because it reduced the amount of monkey patching needed: the processor already exists on the Transformers side, so it's needlessly complicated to monkey patch the processing on the vLLM side.

@DarkLight1337 do you know more about the key differences between supports_multimodal_raw_input_only being True/False?

Member

since the maximum number of videos possible during inference is indeed based on max_num_seqs

Why is this not also the case when is_multimodal_raw_input_only_model=False?

return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
w, h, _ = image.shape
Member

I doubt this was wrong for all the models already in vLLM; why does it need changing here?
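For context, a small self-contained illustration of the layout ambiguity behind this question (assuming CHW torch tensors vs. HWC numpy frames, which is what the two unpackings imply):

import numpy as np
import torch

frame_chw = torch.zeros(3, 480, 640)                 # (C, H, W), e.g. torchvision-style tensors
frame_hwc = np.zeros((480, 640, 3), dtype=np.uint8)  # (H, W, C), e.g. many video decoders

_, h, w = frame_chw.shape  # the original unpacking: correct for CHW (h=480, w=640)
h, w, _ = frame_hwc.shape  # HWC frames put height first, so the axes differ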

processor = self.info.get_hf_processor()
if "gemma3" in processor.__class__.__name__.lower():
image_token = processor.boi_token
video_token = ""
Member

Suggested change
video_token = ""

else:
image_token = getattr(processor, "image_token", "")
return image_token * num_images
video_token = getattr(processor, "video_token", "")
Member

Given that Gemma3 will have no video_token, this should be fine, right?

Suggested change
video_token = getattr(processor, "video_token", "")
video_token = getattr(processor, "video_token", "")

Comment on lines +151 to +156
"video": self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos,
),
Member

We should also provide video overrides, right?

kwargs.pop("token_type_ids", None) # used only in `forward`

if pixel_values is not None:
num_image_patches = kwargs.pop("num_image_patches")
Member

I think this was outside the if block so that it was always popped from kwargs regardless of whether we used it. The same should probably be done for num_video_patches.
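Roughly this ordering (a sketch of the shape only, not the actual code; the real logic is in the hunk above):

def get_multimodal_embeddings_sketch(pixel_values=None, pixel_values_videos=None, **kwargs):
    # Pop both patch counts up front so they never leak into the HF forward,
    # even when the corresponding modality is absent from the batch.
    num_image_patches = kwargs.pop("num_image_patches", None)
    num_video_patches = kwargs.pop("num_video_patches", None)

    embeddings = ()
    if pixel_values is not None:
        ...  # image path uses num_image_patches
    if pixel_values_videos is not None:
        ...  # video path uses num_video_patches
    return embeddings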

Comment on lines +450 to +452
if isinstance(vision_embeddings, tuple):
# For qwen3 vl, The deepstack visual features are also returned
vision_embeddings = vision_embeddings[0]
Member

This might not be intentional. @zucchini-nlp should the return value of Qwen3 VL have been changed to this?

Contributor

We never supported Qwen3-VL in vLLM because of this 😿 The model has to propagate the deepstack visual features down to the LM's forward, which we could theoretically overcome.

We are currently standardizing get_image_features on the Transformers side, and I think the best approach would be to always ask for a dict output here (i.e. we get all the outputs from the vision encoder) and pass it over to the LM, letting it handle the encoder outputs however it pleases.

Member

Do you mean in vLLM with the Transformers modelling backend? vLLM does support this model natively.

Yeah, a standardisation on the Transformers side would be great.

Contributor

Yes, with the backend the model was not supported when released

Comment on lines +467 to +495
multimodal_embeddings += tuple(vision_embeddings)

if video_embeds is not None:
multimodal_embeddings += tuple(video_embeds)

if pixel_values_videos is not None:
num_video_patches = kwargs.pop("num_video_patches")
vision_embeddings = self.model.get_video_features(
pixel_values_videos, **kwargs
)

if isinstance(vision_embeddings, tuple):
# For qwen3 vl, The deepstack visual features are also returned
vision_embeddings = vision_embeddings[0]
if isinstance(vision_embeddings, torch.Tensor):
if vision_embeddings.ndim == 2:
vision_embeddings = vision_embeddings.unsqueeze(0)

# Embeddings have to be 2D tensors of length `num_images`
# but transformers returns concat tensors if each patch
# is of different size. We split it back to make vLLM happy
vision_embeddings = torch.split(
vision_embeddings, num_video_patches.flatten().tolist()
)
vision_embeddings = [
embed.flatten(start_dim=0, end_dim=-2)
for embed in vision_embeddings
]
multimodal_embeddings += tuple(vision_embeddings)
Member

Similar comment to the above: this is very similar to the image code. Could they be deduplicated?
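One possible shape for that deduplication (a hypothetical helper sketched from the two hunks, not what the PR currently does):

import torch


def split_vision_embeddings(vision_embeddings, num_patches: torch.Tensor):
    # Shared post-processing for image and video features, mirroring the
    # duplicated logic: drop auxiliary outputs (e.g. Qwen3-VL deepstack
    # features), split the concatenated tensor back into per-item embeddings,
    # and flatten each to 2D.
    if isinstance(vision_embeddings, tuple):
        vision_embeddings = vision_embeddings[0]
    if isinstance(vision_embeddings, torch.Tensor):
        if vision_embeddings.ndim == 2:
            vision_embeddings = vision_embeddings.unsqueeze(0)
        vision_embeddings = torch.split(
            vision_embeddings, num_patches.flatten().tolist()
        )
    return tuple(
        embed.flatten(start_dim=0, end_dim=-2) for embed in vision_embeddings
    )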

Comment on lines 523 to 524
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])
Member

Now that we're not trying to create empty tensors with these:

Suggested change
-image_grid_thw = kwargs.get("image_grid_thw", [])
-video_grid_thw = kwargs.get("video_grid_thw", [])
+image_grid_thw = kwargs.get("image_grid_thw", None)
+video_grid_thw = kwargs.get("video_grid_thw", None)

Comment on lines +526 to +527
image_grid_thw = torch.stack(image_grid_thw) if image_grid_thw else None
video_grid_thw = torch.stack(video_grid_thw) if video_grid_thw else None
Member

@zucchini-nlp was it important that vLLM passed empty grids rather than None to get_rope_index?

Contributor

should be None if the modality is not present

Member

To clarify, it should be None if:

  • the modality is not present in the individual request? (i.e. an image model but no image has been passed)
  • the modality is not present in the model? (i.e. a model which doesn't support video)

if split_sizes:
chunked_mm_positions = torch.split(mm_positions, split_sizes)
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0] == 1]
Member

Avoid computing mm_token_type_ids == 1 twice
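For example (a self-contained sketch with dummy inputs; the variable names follow the hunk above):

import torch

prompt_ids = [101, 32000, 32000, 17]              # dummy prompt token ids
mm_token_type_ids = torch.tensor([[0, 1, 1, 0]])  # dummy per-token type ids

# Compute the multimodal-token mask once and reuse it everywhere it is needed.
mm_mask = mm_token_type_ids[0] == 1
mm_tokens = torch.tensor(prompt_ids)[mm_mask]     # tensor([32000, 32000])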

Member

@DarkLight1337 DarkLight1337 Dec 15, 2025

Also, is there somewhere in Transformers where the token type ID is defined by modality? I don't like this magic number.

Contributor

@zucchini-nlp zucchini-nlp left a comment

Left a few comments from my experience trying to make videos compatible with the backend. IMO this PR has to wait until the necessary changes are made on the Transformers side for video LLMs. I am quite sure we can't support all of them, but we could try to support 70%.

There aren't many models with explicit video support; around 10 in total, I think.

Comment on lines +66 to +67
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
Contributor

I remember that vLLM chooses only one modality when profiling, whichever has more tokens. So I think we will need to safe-get mm_tokens["num_video_tokens"] and otherwise set it to 0, because not all models support videos.
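Something like this, roughly (a sketch; assumes mm_tokens is the dict returned by the HF processor's _get_num_multimodal_tokens and that the counts are plain ints here for brevity):

def max_tokens_by_modality(mm_tokens: dict) -> dict:
    # Fall back to 0 for modalities the processor doesn't report,
    # e.g. models without video support.
    return {
        "image": mm_tokens.get("num_image_tokens", 0),
        "video": mm_tokens.get("num_video_tokens", 0),
    }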

Member

Qwen2-VL would be a good example of this with get_num_frames_with_most_features. But it gets its mm_counts from the limits in get_supported_mm_limits unless the user specified otherwise.

Ideally there would be a way to set get_supported_mm_limits()["video"] to zero when the model doesn't support video.

Contributor

@zucchini-nlp zucchini-nlp Dec 15, 2025

Yeah, that would be great. I can even say how we infer whether a model supports videos from the model class 😄 It has a class attribute, model.input_modalities.

Member

Oh ok, are these class attributes or instance attributes?

Contributor

Class attributes, available from v5 onwards.
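A sketch of how that attribute could gate the limits (hypothetical; assumes the Transformers v5 input_modalities class attribute mentioned above and vLLM's convention that None means no explicit limit):

def get_supported_mm_limits_sketch(model_cls) -> dict:
    # Fall back to image-only behaviour if the attribute is missing (pre-v5 classes).
    modalities = getattr(model_cls, "input_modalities", ("text", "image"))
    return {
        "image": None,                                  # None == no explicit limit
        "video": None if "video" in modalities else 0,  # 0 == video not supported
    }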

else:
image_token = getattr(processor, "image_token", "")
return image_token * num_images
video_token = getattr(processor, "video_token", "")
Contributor

Yeah, this is the part which might fail for some models. There are models that treat videos as a sequence of images and thus don't have a specific video token 🥲 It drove me crazy trying to make them work.

Comment on lines +202 to +203
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
Contributor

Not super generalizable: only Qwen-like models have a THW tensor. Ideally we should use num_video_patches, as in the image modality.


mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
image_sizes=image_sizes, **mm_processor_kwargs
image_sizes=image_sizes, video_sizes=video_sizes, **mm_processor_kwargs
Contributor

I don't think all processors expect video_sizes, though it will be swallowed by kwargs and shouldn't raise issues.

Comment on lines +326 to +331
ranges = [
PlaceholderRange(
offset=positions[0].item(),
length=positions.shape[0],
is_embed=(mm_tokens == hf_processor.video_token_id).bool(),
)
Contributor

The hard part here is models which add timestamps between video frames: the timestamps are encoded differently and have no special tokens, so inferring the ranges is not as easy as with images.

mm_placeholders["video"] = ranges

processed_data["num_video_patches"] = torch.tensor(
mm_tokens_per_modality["num_video_patches"]
Contributor

mm_tokens_per_modality isn't guaranteed to contain num_video_patches if the model doesn't support the video modality. The same goes for the num_video_tokens key.
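e.g. a guarded version (sketch; assumes a missing key simply means the modality wasn't processed):

import torch


def maybe_add_video_patches(processed_data: dict, mm_tokens_per_modality: dict) -> None:
    # Only materialise the tensor when the processor actually reported video patches.
    num_video_patches = mm_tokens_per_modality.get("num_video_patches")
    if num_video_patches is not None:
        processed_data["num_video_patches"] = torch.tensor(num_video_patches)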

vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)

if isinstance(vision_embeddings, tuple):
# For qwen3 vl, The deepstack visual features are also returned
Contributor

Right, though Qwen3-VL will give bad performance without the deepstack visual features. Currently vLLM doesn't support models like Qwen3-VL and Ovis2 on purpose.

Comment on lines +485 to +491
# Embeddings have to be 2D tensors of length `num_images`
# but transformers returns concat tensors if each patch
# is of different size. We split it back to make vLLM happy
vision_embeddings = torch.split(
vision_embeddings, num_video_patches.flatten().tolist()
)
vision_embeddings = [
Contributor

I am not 100% sure, but for the video modality this might not be needed. Videos mostly don't have several patches per frame.
