Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest
from packaging.version import Version
from transformers import PretrainedConfig
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm.config.model import ModelDType, TokenizerMode
Expand Down Expand Up @@ -985,7 +986,26 @@ def check_available_online(
trust_remote_code=True,
),
"NemotronH_Nano_VL_V2": _HfExamplesInfo(
"nano_vl_dummy", is_available_online=False, trust_remote_code=True
"nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16",
max_model_len=4096,
# NemotronH layers are constructed via `hybrid_override_pattern`:
use_original_num_layers=True,
hf_overrides={
"vision_config": PretrainedConfig(
args={
"min_num_patches": 1, # Trigger image dynamic res
"max_num_patches": 12,
"model": "vit_huge_patch16_224",
},
# Trigger conv3d:
video_temporal_patch_size=2,
),
"text_config": {
"num_hidden_layers": 2,
"hybrid_override_pattern": "M*",
},
},
trust_remote_code=True,
),
"OpenCUAForConditionalGeneration": _HfExamplesInfo(
"xlangai/OpenCUA-7B", trust_remote_code=True
Expand Down
9 changes: 8 additions & 1 deletion tests/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,16 @@ def dummy_hf_overrides(
Dummy HF overrides function used to create dummy model
with only minimum nums of layer.
"""
hf_config.update(exist_overrides or {})
# Copy because this helper is called more than once
# while loading config, and we `.pop()`
exist_overrides = (exist_overrides or {}).copy()
text_config_override = exist_overrides.pop("text_config", None)
hf_config.update(exist_overrides)

text_config = hf_config.get_text_config()
if text_config_override is not None:
# multimodal test models may override *some* text-model fields
text_config.update(text_config_override)

# Ensure at least 2 expert per group
# Since `grouped_topk` assumes top-2
Expand Down
52 changes: 29 additions & 23 deletions vllm/model_executor/models/nano_nemotron_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# LICENSE is in root directory.
# --------------------------------------------------------

import copy
import math
import warnings
from collections.abc import Iterable, Mapping, Sequence
Expand All @@ -17,7 +16,7 @@

import torch
import torch.nn as nn
from transformers import BatchFeature
from transformers import BatchFeature, PretrainedConfig

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
Expand Down Expand Up @@ -210,11 +209,15 @@ def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:

@cached_property
def is_dynamic_tiler(self) -> bool:
return self.get_hf_processor().dynamic_tiler is not None
return BaseNanoNemotronVLProcessor.use_dynamic_resolution(self.get_hf_config())

@cached_property
@property
def supports_video(self):
return self.get_hf_processor().supports_video
return True

@property
def supports_audio(self) -> bool:
return self.sound_config is not None

def get_video_token(self) -> str | None:
return IMG_CONTEXT
Expand All @@ -223,23 +226,23 @@ def get_video_pruning_rate(self) -> float | None:
return self.ctx.get_mm_config().video_pruning_rate

@property
def audio_extractor(self) -> ParakeetExtractor | None:
return self.get_hf_processor().audio_extractor
def sound_config(self) -> PretrainedConfig | None:
return getattr(self.get_hf_config(), "sound_config", None)

def get_default_tok_params(self) -> TokenizeParams:
return super().get_default_tok_params().with_kwargs(add_special_tokens=False)

def get_supported_mm_limits(self) -> Mapping[str, int | None]:
image_limit = {"image": None}
video_limit = {"video": None} if self.supports_video else {}
audio_limit = {"audio": None} if self.audio_extractor is not None else {}
audio_limit = {"audio": None} if self.supports_audio else {}
return {**image_limit, **video_limit, **audio_limit}

def get_data_parser(self):
target_sr = None
target_channels = None
if extractor := self.audio_extractor:
target_sr = extractor.sampling_rate
if self.sound_config:
target_sr = self.sound_config.sampling_rate
target_channels = 1

return MultiModalDataParser(
Expand Down Expand Up @@ -371,7 +374,7 @@ def _get_mm_fields_config(
fields = self._get_image_fields_config(hf_inputs)
if self.info.supports_video:
fields |= self._get_video_fields_config(hf_inputs)
if self.info.audio_extractor:
if self.info.supports_audio:
fields |= self._get_audio_fields_config(hf_inputs)

return fields
Expand Down Expand Up @@ -399,9 +402,8 @@ def get_image_replacement(item_idx: int):

if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
elif tiler := hf_processor.dynamic_tiler:
image = images.get(item_idx)
feature_size = tiler.get_cached_feature_size(image)
elif self.info.is_dynamic_tiler:
feature_size = out_mm_data["num_tokens_per_image"][item_idx]
else:
image_size = images.get_image_size(item_idx)
max_num_tiles = hf_processor.max_num_tiles
Expand Down Expand Up @@ -536,7 +538,7 @@ def _get_prompt_updates(
prompt_repls.append(
self._get_prompt_repl_video(mm_items, hf_processor, out_mm_data)
)
if self.info.audio_extractor:
if self.info.supports_audio:
prompt_repls.append(
self._get_prompt_repl_audio(mm_items, hf_processor, out_mm_data)
)
Expand Down Expand Up @@ -772,12 +774,14 @@ def get_dummy_mm_data(
else:
dummy_video = {}

if extractor := self.info.audio_extractor:
if sound_config := self.info.sound_config:
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
tokens_per_audio = max(1, seq_len // max(num_audios, 1))
max_audio_num_samples = MAX_AUDIO_LEN_S * extractor.sampling_rate
calculated_max_audio_num_samples = extractor.audio_length(tokens_per_audio)
max_audio_num_samples = MAX_AUDIO_LEN_S * sound_config.sampling_rate
calculated_max_audio_num_samples = ParakeetExtractor.audio_length(
sound_config, tokens_per_audio
)
audio_len = min(max_audio_num_samples, calculated_max_audio_num_samples)
dummy_audio = {
"audio": self._get_dummy_audios(
Expand Down Expand Up @@ -1029,9 +1033,13 @@ def _parse_and_validate_image_input(
data=image_embeds,
)

pixel_values_flat = kwargs.pop("pixel_values_flat", None)
if pixel_values_flat is None:
return None

if self.dynamic_resolution:
pixel_values_flat = DynamicResolutionImageTiler.stack(
kwargs.pop("pixel_values_flat"), self.patch_size
pixel_values_flat, self.patch_size
)
return NanoNemotronVLImagePixelInputsDynamic(
pixel_values_flat=pixel_values_flat, **kwargs
Expand Down Expand Up @@ -1498,15 +1506,13 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
@classmethod
def get_mamba_state_shape_from_config(cls, vllm_config: "VllmConfig"):
text_config = vllm_config.model_config.hf_config.text_config
temp_vllm_config = copy.deepcopy(vllm_config)
temp_vllm_config.model_config.hf_config = text_config
temp_vllm_config = vllm_config.with_hf_config(text_config)
return NemotronHForCausalLM.get_mamba_state_shape_from_config(temp_vllm_config)

@classmethod
def get_mamba_state_dtype_from_config(cls, vllm_config: "VllmConfig"):
text_config = vllm_config.model_config.hf_config.text_config
temp_vllm_config = copy.deepcopy(vllm_config)
temp_vllm_config.model_config.hf_config = text_config
temp_vllm_config = vllm_config.with_hf_config(text_config)
return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config)

@classmethod
Expand Down
6 changes: 4 additions & 2 deletions vllm/model_executor/models/parakeet.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,5 +159,7 @@ def __call__(self, raw_speech: list[np.ndarray], *args, **kwargs):
outputs["audio_num_clips"] = audio_num_clips
return outputs

def audio_length(self, audio_tokens: int) -> int:
return int(audio_tokens * self.config.subsampling_factor * self.hop_length)
@staticmethod
def audio_length(raw_config: PretrainedConfig, audio_tokens: int) -> int:
config = ExtractorConfig.from_hf_config(raw_config)
return int(audio_tokens * config.subsampling_factor * config.hop_length)
11 changes: 1 addition & 10 deletions vllm/model_executor/models/radio.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def __init__(
temporal_patch_size=temporal_patch_size,
**factory,
)
self._video_embedder_loaded = False

if abs_pos:
scale = embed_dim**-0.5
Expand Down Expand Up @@ -225,12 +224,7 @@ def forward_video(self, x: torch.Tensor) -> torch.Tensor:
Returns:
Embedded patches with temporal compression applied.
"""
if not self._video_embedder_loaded:
raise ValueError(
"Temporal compression (video_temporal_patch_size > 1) requires "
"video_embedder weights, but they were never loaded. "
"Ensure the checkpoint was trained with temporal compression."
)
assert self.temporal_patch_size > 1
T = self.temporal_patch_size
input_size = x.shape[2:]

Expand Down Expand Up @@ -794,9 +788,6 @@ def load_weights(self, weights) -> set[str]:
weight_loader(param, weight)
loaded_params.add(vllm_key)

if "model.patch_generator.video_embedder.weight" in loaded_params:
self.model.patch_generator._video_embedder_loaded = True

return loaded_params

def _extract_final(
Expand Down
4 changes: 4 additions & 0 deletions vllm/transformers_utils/configs/parakeet.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,19 @@ class ExtractorConfig:
subsampling_factor: int
subsampling_conv_kernel_size: int
subsampling_conv_stride: int
hop_length: int = 160
"""Default `160`: Matches HF default"""
clip_duration_s: int = 30
clip_min_duration_s: float = 0.1

@staticmethod
def from_hf_config(config: PretrainedConfig) -> "ExtractorConfig":
assert isinstance(config, PretrainedConfig)
hop_length = int(getattr(config, "hop_length", ExtractorConfig.hop_length))
return ExtractorConfig(
feature_size=config.num_mel_bins,
sampling_rate=config.sampling_rate,
hop_length=hop_length,
subsampling_factor=config.subsampling_factor,
subsampling_conv_kernel_size=config.subsampling_conv_kernel_size,
subsampling_conv_stride=config.subsampling_conv_stride,
Expand Down
32 changes: 17 additions & 15 deletions vllm/transformers_utils/processors/nano_nemotron_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,15 +356,6 @@ def _images_to_pixel_values_lst(
feature_sizes.append(param.num_embeddings)
return images, feature_sizes

feature_size_cache: dict[Image.Image, int] = {}

@classmethod
def get_cached_feature_size(cls, image: Image.Image) -> int:
feature_size = cls.feature_size_cache[id(image)]
# hard assert that we only use the feature size once
del cls.feature_size_cache[id(image)]
return feature_size

@dataclass
class DynamicResolutionParams:
media: Image.Image
Expand Down Expand Up @@ -519,7 +510,6 @@ def compute_params(
param, token_count = self.process_media(media, tokens_for_media)
params.append(param)
token_counts.append(token_count)
self.feature_size_cache[id(param.media)] = param.num_embeddings

# Step 2: Check if total tokens is within budget
total_tokens = sum(token_counts)
Expand Down Expand Up @@ -857,13 +847,12 @@ def num_video_token(self) -> int:

@property
def supports_video(self) -> bool:
return self.video_token_id is not None
return True

@property
def video_token_id(self) -> int | None:
if self.video_token is None:
return None
return self.tokenizer.get_vocab().get(self.video_token, None)
def video_token_id(self) -> int:
assert self.video_token is not None
return self.tokenizer.get_vocab()[self.video_token]

@property
def image_token_id(self) -> int:
Expand Down Expand Up @@ -1055,6 +1044,13 @@ def __call__(
text_inputs = self.tokenizer(text, add_special_tokens=False)

combined_inputs = {**text_inputs, **video_inputs, **audio_inputs}
frames_indices = combined_inputs.get("frames_indices")
ragged_frames_indices = (
isinstance(frames_indices, list)
and len({len(frame_indices) for frame_indices in frames_indices}) > 1
)
if ragged_frames_indices:
combined_inputs.pop("frames_indices")

if self.dynamic_tiler is None:
batch = BatchFeature(
Expand All @@ -1066,6 +1062,12 @@ def __call__(
# allow images to be exempt from the BatchFeature validation:
# We will .stack() them in _parse_and_validate_image_input
batch.update(image_inputs)
if ragged_frames_indices:
assert isinstance(frames_indices, list)
batch["frames_indices"] = [
torch.as_tensor(frame_indices, dtype=torch.int64)
for frame_indices in frames_indices
]
return batch

def get_image_repl(
Expand Down
Loading