Skip to content

Commit e71a200

Browse files
authored
[TRTLLM-9019][feat] Expose video_pruning_rate as llmargs and fix nano-v2-vl (NVIDIA#12194)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent 355d993 commit e71a200

File tree

15 files changed

+151
-40
lines changed

15 files changed

+151
-40
lines changed

examples/llm-api/quickstart_multimodal.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ def add_multimodal_args(parser):
132132
type=int,
133133
default=2,
134134
help="Number of conversation turns for automated testing.")
135+
parser.add_argument("--video_pruning_rate",
136+
type=float,
137+
default=None,
138+
help="Pruning rate for video frames (EVS). "
139+
"None disables EVS, values in [0, 1) enable pruning.")
135140
return parser
136141

137142

@@ -181,7 +186,9 @@ def main():
181186
lora_config.max_loras = 2
182187
lora_config.max_cpu_loras = 2
183188

184-
llm, sampling_params = setup_llm(args, lora_config=lora_config)
189+
llm, sampling_params = setup_llm(args,
190+
lora_config=lora_config,
191+
video_pruning_rate=args.video_pruning_rate)
185192

186193
image_format = args.image_format
187194
if args.model_type is not None:

examples/models/core/nemotron/README_nano-v2-vl.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ python3 examples/llm-api/quickstart_multimodal.py --model_dir nvidia/NVIDIA-Nemo
3838
* Video modality input with Efficient video sampling (EVS):
3939

4040
```bash
41-
TLLM_VIDEO_PRUNING_RATIO=0.9 python3 examples/llm-api/quickstart_multimodal.py --model_dir nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16 --disable_kv_cache_reuse --max_batch_size 128 --trust_remote_code --modality video --max_num_tokens 131072
41+
python3 examples/llm-api/quickstart_multimodal.py --model_dir nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16 --disable_kv_cache_reuse --max_batch_size 128 --trust_remote_code --modality video --max_num_tokens 131072 --video_pruning_rate 0.9
4242
```
4343

4444
## Online serving example CMDs
@@ -55,7 +55,7 @@ EOF
5555

5656
# CMD to launch serve without EVS.
5757
trtllm-serve \
58-
nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16\
58+
nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16 \
5959
--host 0.0.0.0 \
6060
--port 8000 \
6161
--backend pytorch \
@@ -65,16 +65,17 @@ nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16\
6565
--media_io_kwargs "{\"video\": {\"fps\": 2, \"num_frames\": 128} }" \
6666
--config config.yml
6767

68-
# CMD to launch serve with EVS (video_pruning_ratio=0.9).
69-
TLLM_VIDEO_PRUNING_RATIO=0.9 trtllm-serve \
70-
nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16\
68+
# CMD to launch serve with EVS (video_pruning_rate=0.9).
69+
trtllm-serve \
70+
nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16 \
7171
--host 0.0.0.0 \
7272
--port 8000 \
7373
--backend pytorch \
7474
--max_batch_size 16 \
7575
--max_num_tokens 131072 \
7676
--trust_remote_code \
7777
--media_io_kwargs "{\"video\": {\"fps\": 2, \"num_frames\": 128} }" \
78+
--video_pruning_rate 0.9 \
7879
--config config.yml
7980
```
8081

tensorrt_llm/_torch/model_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ class ModelConfig(Generic[TConfig]):
131131
# If true, ONLY the vision encoder part of the full model is loaded/executed.
132132
mm_encoder_only: bool = False
133133

134+
# Video pruning rate for VLM models (None = EVS disabled)
135+
video_pruning_rate: Optional[float] = None
136+
134137
def __setattr__(self, key, value):
135138
"""
136139
Prevent modification of frozen instance attributes.

tensorrt_llm/_torch/models/modeling_nemotron_nano.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
from .modeling_radio import RADIOVisionModel, calc_seq_lens
4343
from .modeling_utils import register_auto_model
4444

45-
VIDEO_PRUNING_RATIO = float(os.getenv("TLLM_VIDEO_PRUNING_RATIO", "0"))
4645
# Set max_num_tiles to 1 for video modality, to match the training behavior.
4746
VIDEO_MAX_NUM_TILES = 1
4847
IMAGE_PLACEHOLDER = "<image>"
@@ -257,7 +256,10 @@ def __init__(self, model_config: ModelConfig[transformers.PretrainedConfig]):
257256
raise NotImplementedError(
258257
f"Unsupported {config.ps_version=}. Supported versions: {supported_versions}."
259258
)
260-
self.video_pruning_ratio = VIDEO_PRUNING_RATIO
259+
# Use config value if explicitly set (EVS enabled), otherwise default to 0.0 (EVS disabled)
260+
self.video_pruning_rate = (
261+
model_config.video_pruning_rate if model_config.video_pruning_rate is not None else 0.0
262+
)
261263

262264
# Construct the vision projection.
263265
self.vit_hidden_size = config.vit_hidden_size
@@ -414,7 +416,7 @@ def apply_evs_per_video(
414416
video_embeds=reshaped_partial_mm_embed,
415417
video_size=(t, p * ih, iw),
416418
spatial_merge_size=self.spatial_merge_size,
417-
pruning_ratio=self.video_pruning_ratio,
419+
pruning_ratio=self.video_pruning_rate,
418420
flatten_output=False,
419421
).flatten(start_dim=1)
420422
# -> [num_frames, num_patches_per_frame*h*w]
@@ -437,7 +439,7 @@ def apply_evs(
437439
) -> Tuple[List[torch.Tensor], Optional[List[List[int] | None]]]:
438440
"""Apply EVS to the multimodal embedding."""
439441
# Skip EVS if pruning ratio is 0.
440-
if self.video_pruning_ratio <= 0:
442+
if self.video_pruning_rate <= 0:
441443
return mm_embedding, None
442444

443445
modality_types = [
@@ -448,7 +450,7 @@ def apply_evs(
448450
return mm_embedding, None
449451

450452
video_size_list = [
451-
multimodal_data[modality_type]["video_size"]
453+
multimodal_data[modality_type].get("video_size") if modality_type == "video" else None
452454
for modality_type, multimodal_data in zip(modality_types, multimodal_data_lst)
453455
]
454456
mm_embedding_evs = []
@@ -487,17 +489,23 @@ def forward(
487489
pixel_values_flat = data["pixel_values"]
488490
image_sizes = data["image_sizes"]
489491
embeds = self.extract_feature_dynamic(pixel_values_flat, image_sizes)
490-
mm_embedding.append(embeds.reshape(-1, self.llm_hidden_size))
492+
# Keep 3D shape for apply_evs, will reshape to 2D after EVS
493+
mm_embedding.append(embeds)
491494
# This applies to images without dynamic resolution, or videos.
492495
else:
493496
# Fallback to fixed-tile extraction for this modality.
494497
pixel_values = data["pixel_values"]
495498
embeds = self.extract_feature(pixel_values)
496-
mm_embedding.append(embeds.reshape(-1, self.llm_hidden_size))
499+
# Keep 3D shape [num_patches, h*w, hidden] for apply_evs
500+
mm_embedding.append(embeds)
497501

498-
return mm_embedding, [None] * len(modality_types)
502+
# Apply EVS if video_pruning_rate > 0
503+
mm_embedding, num_tokens_in_videos = self.apply_evs(mm_embedding, multimodal_data_lst)
504+
# Reshape to 2D after EVS: [num_patches*h*w, hidden_size]
505+
mm_embedding = [m.reshape(-1, self.llm_hidden_size) for m in mm_embedding]
506+
return mm_embedding, num_tokens_in_videos
499507

500-
# Existing fixed-tile path.
508+
# Existing fixed-tile path (unreachable, kept for reference).
501509
pixel_values = [
502510
multimodal_data[modality_type]["pixel_values"]
503511
for modality_type, multimodal_data in zip(modality_types, multimodal_data_lst)
@@ -530,6 +538,9 @@ def __init__(
530538
trust_remote_code: bool = True,
531539
**kwargs,
532540
):
541+
# Extract video_pruning_rate before passing kwargs to parent
542+
video_pruning_rate = kwargs.pop("video_pruning_rate", None) or 0.0
543+
533544
super().__init__(
534545
model_path=model_path,
535546
config=config,
@@ -563,7 +574,7 @@ def __init__(
563574
self.num_image_token = int(
564575
(self.image_size // self.patch_size) ** 2 * (self.downsample_ratio**2)
565576
)
566-
self.video_pruning_ratio = VIDEO_PRUNING_RATIO
577+
self.video_pruning_rate = video_pruning_rate
567578
self.img_context_token = self.config.img_context_token
568579
self.video_context_token = self.config.video_context_token
569580
self.img_start_token = self.config.img_start_token
@@ -747,15 +758,15 @@ def get_num_tokens_per_video(
747758
self,
748759
*,
749760
video: List[Image.Image],
750-
video_pruning_ratio: Optional[float] = None,
761+
video_pruning_rate: Optional[float] = None,
751762
**kwargs,
752763
):
753764
# Fall back to self.video_pruning_rate if not explicitly provided
754-
if video_pruning_ratio is None:
755-
video_pruning_ratio = self.video_pruning_ratio
765+
if video_pruning_rate is None:
766+
video_pruning_rate = self.video_pruning_rate
756767

757768
num_frames = len(video)
758-
if video_pruning_ratio > 0:
769+
if video_pruning_rate > 0:
759770
num_tokens_per_frame = self.get_num_tokens_per_image(
760771
image=video[0],
761772
max_num_tiles=VIDEO_MAX_NUM_TILES,
@@ -767,7 +778,7 @@ def get_num_tokens_per_video(
767778
num_total_tokens = compute_retained_tokens_count(
768779
video_size=video_size,
769780
spatial_merge_size=self.spatial_merge_size,
770-
pruning_ratio=video_pruning_ratio,
781+
pruning_ratio=video_pruning_rate,
771782
)
772783
# Add special tokens for each frame.
773784
num_total_tokens += num_frames * len(self.get_mm_special_token_ids())
@@ -776,7 +787,7 @@ def get_num_tokens_per_video(
776787
num_total_tokens = sum(
777788
self.get_num_tokens_per_image(
778789
image=frame,
779-
video_pruning_ratio=None,
790+
video_pruning_rate=None,
780791
max_num_tiles=VIDEO_MAX_NUM_TILES,
781792
**kwargs,
782793
)
@@ -961,7 +972,7 @@ def _process_video_prompts(
961972
processed_query.extend(frame_prompts)
962973
# Video_context_token as placeholder,
963974
# it will be replaced with the real image_tokens_per_frames during model forward.
964-
if self.video_pruning_ratio > 0:
975+
if self.video_pruning_rate > 0:
965976
evs_query.append(split_text_prompt[video_index])
966977
evs_query.append("This is a video:\n")
967978
for frame_sep in frame_separators:
@@ -986,7 +997,7 @@ def _process_video_prompts(
986997
]
987998
input_ids = torch.cat(input_ids_lst, dim=1)
988999

989-
if self.video_pruning_ratio > 0:
1000+
if self.video_pruning_rate > 0:
9901001
evs_query.append(split_text_prompt[-1])
9911002
evs_ids = [
9921003
self.tokenizer.encode(
@@ -1009,11 +1020,11 @@ def _compute_token_numbers_per_video(self, video_size_lst: List[Tuple]) -> List[
10091020
img_height = video_size[2]
10101021
img_width = video_size[3]
10111022

1012-
if self.video_pruning_ratio > 0:
1023+
if self.video_pruning_rate > 0:
10131024
desired_num_tokens = compute_retained_tokens_count(
10141025
video_size=(num_frames, num_patches_per_frame * img_height, img_width),
10151026
spatial_merge_size=self.spatial_merge_size,
1016-
pruning_ratio=self.video_pruning_ratio,
1027+
pruning_ratio=self.video_pruning_rate,
10171028
)
10181029
# It is dummy tokens and will be adjusted in VisionEncoder after applied EVS.
10191030
# Need to know the length of the full input ids ahead,
@@ -1069,7 +1080,7 @@ def __call__(
10691080
# Store input_ids for image modality here when EVS is enabled,
10701081
# which will be used in merge_evs_mm_embeds later.
10711082
modality_data["evs_ids"] = (
1072-
input_ids[0].to(torch.int32) if self.video_pruning_ratio > 0 else None
1083+
input_ids[0].to(torch.int32) if self.video_pruning_rate > 0 else None
10731084
)
10741085
elif videos is not None:
10751086
modality_type = "video"
@@ -1249,7 +1260,10 @@ def __init__(self, model_config: ModelConfig):
12491260
self.sound_context_token_id = getattr(config, "sound_context_token_id", None)
12501261
self.post_config()
12511262
self.is_loaded = True
1252-
self.video_pruning_ratio = VIDEO_PRUNING_RATIO
1263+
# Use config value if explicitly set (EVS enabled), otherwise default to 0.0 (EVS disabled)
1264+
self.video_pruning_rate = (
1265+
model_config.video_pruning_rate if model_config.video_pruning_rate is not None else 0.0
1266+
)
12531267

12541268
def load_weights(self, weights):
12551269
# Load vision encoder weights.
@@ -1378,7 +1392,7 @@ def _encode_multimodal(
13781392
if modality_type in ("image", "video"):
13791393
embs, num_tokens = self.vision_encoder([param])
13801394
mm_embeddings.append(embs[0])
1381-
mm_num_tokens.append(num_tokens[0])
1395+
mm_num_tokens.append(num_tokens[0] if num_tokens is not None else None)
13821396
elif modality_type == "audio":
13831397
mm_embeddings.append(self._encode_audio(param))
13841398
mm_num_tokens.append(None)
@@ -1421,7 +1435,7 @@ def forward(
14211435
"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
14221436
)
14231437
# Adjust input_ids in videos if EVS is applied.
1424-
if self.video_pruning_ratio > 0:
1438+
if self.video_pruning_rate > 0:
14251439
input_ids = self.merge_evs_mm_embeds(
14261440
num_tokens_in_videos,
14271441
multimodal_params=multimodal_params[:num_context_requests],

tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def _extract_transpose_prefill_kernel(
4242
conv_mask = conv_offsets < conv_dim
4343
mask = seq_mask[:, None] & conv_mask[None, :]
4444

45-
src_offsets = seq_offsets[:, None] * d_in_proj + (d_inner + conv_offsets[None, :])
45+
# Cast to int64 to avoid overflow: seq_offsets * d_in_proj can exceed INT32_MAX
46+
# (e.g., 131071 * 22656 = 2,969,544,576 > 2,147,483,647)
47+
src_offsets = seq_offsets[:, None].to(tl.int64) * d_in_proj + d_inner + conv_offsets[None, :]
4648
data = tl.load(src_ptr + src_offsets, mask=mask, other=0.0)
4749

4850
dst_offsets = conv_offsets[:, None] * num_prefill_tokens + seq_offsets[None, :]

tensorrt_llm/_torch/modules/mamba/layernorm_gated.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def _layer_norm_fwd_1pass_kernel(
5252
X += row * stride_x_row + group * N
5353
Y += row * stride_y_row + group * N
5454
if HAS_Z:
55-
Z += row * stride_z_row + group * N
55+
# Cast to int64 to avoid overflow: row * stride_z_row can exceed INT32_MAX
56+
# when Z is a non-contiguous slice (e.g., 131071 * 22656 = 2,969,544,576)
57+
Z += tl.cast(row, tl.int64) * stride_z_row + group * N
5658
if not IS_RMS_NORM:
5759
Mean += group * M
5860
Rstd += group * M

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,15 @@ def __init__(
203203
self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
204204
)
205205

206+
input_processor_kwargs = {}
207+
if llm_args.video_pruning_rate is not None:
208+
input_processor_kwargs[
209+
'video_pruning_rate'] = llm_args.video_pruning_rate
206210
self.input_processor = create_input_processor(
207211
model_path,
208212
tokenizer=None,
209-
checkpoint_format=llm_args.checkpoint_format)
213+
checkpoint_format=llm_args.checkpoint_format,
214+
**input_processor_kwargs)
210215
self.input_processor_with_hash = create_input_processor_with_hash(
211216
self.input_processor)
212217
if model is None:

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ def _load_and_validate_config(
485485
use_cute_dsl_blockscaling_mm,
486486
use_cute_dsl_blockscaling_bmm=self.llm_args.
487487
use_cute_dsl_blockscaling_bmm,
488+
video_pruning_rate=self.llm_args.video_pruning_rate,
488489
)
489490

490491
# Only pass model_kwargs if it's explicitly set (not None)

tensorrt_llm/commands/serve.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def get_llm_args(
154154
fail_fast_on_attention_window_too_large: bool = True,
155155
otlp_traces_endpoint: Optional[str] = None,
156156
enable_chunked_prefill: bool = False,
157+
video_pruning_rate: Optional[float] = None,
157158
**llm_args_extra_dict: Any):
158159

159160
if gpus_per_node is None:
@@ -236,6 +237,8 @@ def get_llm_args(
236237
otlp_traces_endpoint,
237238
"fail_fast_on_attention_window_too_large":
238239
fail_fast_on_attention_window_too_large,
240+
"video_pruning_rate":
241+
video_pruning_rate,
239242
}
240243

241244
llm_args = {
@@ -718,6 +721,14 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
718721
default=None,
719722
help=help_info_with_stability_tag(
720723
"Keyword arguments for media I/O.", "prototype"))
724+
@click.option("--video_pruning_rate",
725+
type=float,
726+
default=None,
727+
help=help_info_with_stability_tag(
728+
"Pruning rate for video frames in multimodal models. "
729+
"Applied by Efficient Video Sampling (EVS). "
730+
"None disables EVS, values in [0, 1) enable pruning.",
731+
"prototype"))
721732
@click.option("--chat_template",
722733
type=str,
723734
default=None,
@@ -760,8 +771,9 @@ def serve(
760771
fail_fast_on_attention_window_too_large: bool,
761772
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
762773
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
763-
custom_module_dirs: list[Path], chat_template: Optional[str],
764-
grpc: bool, served_model_name: Optional[str],
774+
video_pruning_rate: Optional[float], custom_module_dirs: list[Path],
775+
chat_template: Optional[str], grpc: bool,
776+
served_model_name: Optional[str],
765777
extra_visual_gen_options: Optional[str]):
766778
"""Running an OpenAI API compatible server
767779
@@ -815,7 +827,8 @@ def _serve_llm():
815827
fail_fast_on_attention_window_too_large=
816828
fail_fast_on_attention_window_too_large,
817829
otlp_traces_endpoint=otlp_traces_endpoint,
818-
enable_chunked_prefill=enable_chunked_prefill)
830+
enable_chunked_prefill=enable_chunked_prefill,
831+
video_pruning_rate=video_pruning_rate)
819832

820833
llm_args_extra_dict = {}
821834
if extra_llm_api_options is not None:

0 commit comments

Comments
 (0)