Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion vllm/compilation/passes/fusion/act_quant_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@
if silu_and_mul_nvfp4_quant_supported:
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501

if current_platform.is_cuda_alike():
# Check if the per-block quant operation is available (newer ROCm/CUDA versions)
if current_platform.is_cuda_alike() and hasattr(
torch.ops._C, "silu_and_mul_per_block_quant"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this related to the NPU work?

):
FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default

Expand Down
15 changes: 15 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False
VLLM_VISION_NPU_BACKEND: str = ""
VLLM_VISION_NPU_CACHE: str | None = None
VLLM_VISION_NPU_DEVICE: str | None = None
VLLM_NPU_ASYNC_PIPELINE: bool = False
VLLM_NPU_TIMING: bool = False
VLLM_MORIIO_QP_PER_TRANSFER: int = 1
VLLM_MORIIO_POST_BATCH_SIZE: int = -1
VLLM_MORIIO_NUM_WORKERS: int = 1
Expand Down Expand Up @@ -1744,6 +1749,16 @@ def _get_or_set_default() -> str:
# Disable PDL for LoRA, as enabling PDL with LoRA on SM100 causes
# Triton compilation to fail.
"VLLM_LORA_DISABLE_PDL": lambda: bool(int(os.getenv("VLLM_LORA_DISABLE_PDL", "0"))),
# NPU vision backend to use (e.g., "flexmlrt" for FlexMLRT backend)
"VLLM_VISION_NPU_BACKEND": lambda: os.getenv("VLLM_VISION_NPU_BACKEND", ""),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is only one backend, right? I.e., could we drop this env var?

# Path to NPU model cache directory (required for FlexMLRT backend)
"VLLM_VISION_NPU_CACHE": lambda: os.getenv("VLLM_VISION_NPU_CACHE"),
# NPU device name (e.g., "stx" for Strix, "phx" for Phoenix)
"VLLM_VISION_NPU_DEVICE": lambda: os.getenv("VLLM_VISION_NPU_DEVICE"),
# Enable async pipelining of NPU vision encoding with GPU LLM inference
"VLLM_NPU_ASYNC_PIPELINE": lambda: os.getenv("VLLM_NPU_ASYNC_PIPELINE", "0") == "1",
# Enable NPU timing debug logs
"VLLM_NPU_TIMING": lambda: os.getenv("VLLM_NPU_TIMING", "0") == "1",
# Enable CUDA compatibility mode for datacenter GPUs with older
# driver versions than the CUDA toolkit major version of vLLM.
"VLLM_ENABLE_CUDA_COMPATIBILITY": lambda: (
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def __init__(
else:
self.norm = PPMissingLayer()

def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
Expand Down
194 changes: 188 additions & 6 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,18 +581,40 @@ def __init__(
) -> None:
super().__init__()

# Store minimal config needed for both NPU and PyTorch paths

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether making a new Qwen2_5_VisionTransformerNPU class would be cleaner instead of doing the overwriting/conditionals here. Can you try that?

self.out_hidden_size = vision_config.out_hidden_size
self.spatial_merge_size = vision_config.spatial_merge_size
self.spatial_merge_unit = self.spatial_merge_size**2

# Check NPU backend before creating PyTorch modules
from vllm.model_executor.models.vision import (
get_npu_vision_backend,
use_npu_vision_backend,
)

if use_npu_vision_backend():
try:
self.npu_backend = get_npu_vision_backend()
logger.info("[Qwen2.5VL] Using NPU vision backend")
return
except Exception as e:
logger.error("[Qwen2.5VL] NPU backend init failed: %s", e)
raise RuntimeError(
f"NPU vision backend initialization failed: {e}. "
"Set VLLM_VISION_NPU_BACKEND='' to use PyTorch backend."
) from e

self.npu_backend = None
patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
in_channels = vision_config.in_channels
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.out_hidden_size = vision_config.out_hidden_size

# args for get_window_index_thw
self.window_size = vision_config.window_size
self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
self.spatial_merge_unit = self.spatial_merge_size**2
self.patch_embed = Qwen2_5_VisionPatchEmbed(
Expand Down Expand Up @@ -653,11 +675,22 @@ def __init__(

@property
def dtype(self) -> torch.dtype:
return self.patch_embed.proj.weight.dtype
if hasattr(self, "npu_backend") and self.npu_backend is not None:
return torch.bfloat16
if hasattr(self, "patch_embed"):
return self.patch_embed.proj.weight.dtype
# Safe fallback if neither exists
return torch.bfloat16

@property
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
if hasattr(self, "npu_backend") and self.npu_backend is not None:
# NPU outputs are on CPU, transfer to GPU happens in forward
return torch.device("cpu")
if hasattr(self, "patch_embed"):
return self.patch_embed.proj.weight.device
# Safe fallback
return torch.device("cpu")

def rotary_pos_emb_thw(self, t, h, w):
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
Expand Down Expand Up @@ -787,6 +820,94 @@ def forward(
x: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
# Dispatch to NPU or PyTorch backend
if hasattr(self, "npu_backend") and self.npu_backend is not None:
return self._forward_npu(x, grid_thw)
else:
return self._forward_pytorch(x, grid_thw)

def _forward_npu(
    self, pixel_values: torch.Tensor, grid_thw: list[list[int]]
) -> torch.Tensor:
    """Run the vision encoder through the NPU backend.

    ``pixel_values`` is converted to a float32 NumPy array and passed to
    ``self.npu_backend.forward`` together with ``grid_thw`` as an int64
    array. The returned array is moved to ``cuda`` as bfloat16.

    The expected row count is ``t * h * w // spatial_merge_size**2`` summed
    over every entry of ``grid_thw``. If the backend returns a different
    number of rows, the rows are repeated (when the expected count is an
    exact multiple) or resampled with nearest-neighbour interpolation so
    that the result has exactly the expected number of rows.

    Args:
        pixel_values: Input tensor handed to the NPU backend.
        grid_thw: One ``[t, h, w]`` triple per item.

    Returns:
        A bfloat16 tensor on ``cuda``.

    Raises:
        RuntimeError: If the backend returns zero rows while ``grid_thw``
            implies a non-zero token count (nothing to repeat from).
    """
    import time

    import numpy as np

    import vllm.envs as envs

    # NumPy has no bfloat16, so that dtype is upcast in torch first.
    if pixel_values.dtype == torch.bfloat16:
        pixel_values_np = pixel_values.cpu().float().numpy()
    else:
        pixel_values_np = pixel_values.cpu().numpy().astype(np.float32)
    grid_thw_np = np.array(grid_thw, dtype=np.int64)

    embeddings_np = self.npu_backend.forward(pixel_values_np, grid_thw_np)

    # Host -> GPU transfer; only timed/logged when VLLM_NPU_TIMING is set.
    timing_enabled = envs.VLLM_NPU_TIMING
    transfer_start = time.monotonic() if timing_enabled else 0.0
    embeddings = torch.from_numpy(embeddings_np).to(
        device="cuda", dtype=torch.bfloat16
    )
    if timing_enabled:
        transfer_ms = (time.monotonic() - transfer_start) * 1000
        logger.debug(
            "[NPU Timing] CPU→GPU transfer: %.2fms (%.2f MB)",
            transfer_ms,
            embeddings_np.nbytes / 1024**2,
        )
        logger.debug("[Vision→LLM] Vision embeddings shape: %s", embeddings.shape)

    # Reconcile the backend's row count with the count implied by grid_thw.
    merge_unit = self.spatial_merge_size * self.spatial_merge_size
    total_expected = sum((t * h * w) // merge_unit for t, h, w in grid_thw)
    actual_tokens = embeddings.shape[0]
    if actual_tokens == total_expected:
        return embeddings

    if actual_tokens == 0:
        # Previously this surfaced as a bare ZeroDivisionError.
        raise RuntimeError(
            "NPU vision backend returned no tokens, but "
            f"{total_expected} were expected based on grid_thw."
        )

    logger.warning(
        "[NPU] Token count mismatch: NPU output %s tokens, "
        "but vLLM expects %s based on grid_thw. "
        "Repeating tokens to match expected count.",
        actual_tokens,
        total_expected,
    )
    # Integer arithmetic avoids relying on float equality for divisibility.
    repeat_factor, remainder = divmod(total_expected, actual_tokens)
    if remainder == 0:
        embeddings = embeddings.repeat_interleave(repeat_factor, dim=0)
    else:
        # interpolate() resizes the trailing dims of a (N, C, ...) tensor,
        # so wrap the 2-D matrix in two singleton dims and strip them after.
        embeddings = embeddings.unsqueeze(0).unsqueeze(0)
        embeddings = torch.nn.functional.interpolate(
            embeddings,
            size=(total_expected, embeddings.shape[-1]),
            mode="nearest",
        )
        embeddings = embeddings.squeeze(0).squeeze(0)

    logger.debug(
        "[NPU] Padded from %s to %s tokens", actual_tokens, embeddings.shape[0]
    )
    return embeddings

def _forward_pytorch(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
"""Original PyTorch forward pass."""
# patchify
seq_len, _ = x.size()
rotary_pos_emb_cos = []
Expand Down Expand Up @@ -889,6 +1010,12 @@ def forward(
return hidden_states

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if self.npu_backend is not None:
logger.info(
"[Qwen2.5VL Vision] Skipping weight loading (using NPU backend)"
)
return set()

stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("attn.qkv.", "attn.q.", "q"),
Expand Down Expand Up @@ -1231,8 +1358,25 @@ def _process_image_input(
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
# When using NPU backend, merge is already done in NPU, so use actual
# output size
if hasattr(self.visual, "npu_backend") and self.visual.npu_backend is not None:
# NPU backend already did spatial merging - use actual output sizes
# For single image: sizes = [actual_num_tokens]
# For batched images: split based on actual output
num_images = len(grid_thw_list)
if num_images == 1:
# Single image - return the whole embedding
sizes = [image_embeds.shape[0]]
else:
# Multiple images - need to split based on actual grid sizes
# Each image: (T*H*W) // (merge_size^2) tokens after NPU
merge_size = self.visual.spatial_merge_size
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
else:
# PyTorch backend - calculate expected size
merge_size = self.visual.spatial_merge_size
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return image_embeds.split(sizes)

def _postprocess_image_embeds_evs(
Expand Down Expand Up @@ -1495,6 +1639,22 @@ def compute_logits(
return self.language_model.compute_logits(hidden_states)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load model weights, dropping vision-tower weights on the NPU path.

    When ``self.visual`` has a non-``None`` ``npu_backend``, every entry
    whose name starts with ``"visual."`` is skipped before the remaining
    weights are handed to ``AutoWeightsLoader``. The filter is a generator,
    so the incoming iterable is consumed lazily instead of being collected
    into a list first.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        Whatever ``AutoWeightsLoader.load_weights`` returns for the
        (possibly filtered) weights.
    """

    def _without_visual(source):
        # Yields non-visual entries; the skip count is reported once the
        # source iterable has been fully consumed by the loader.
        skipped = 0
        for name, weight in source:
            if name.startswith("visual."):
                skipped += 1
                continue
            yield name, weight
        logger.info("[Qwen2.5VL Model] Skipped %s visual weights", skipped)

    if getattr(self.visual, "npu_backend", None) is not None:
        logger.info(
            "[Qwen2.5VL Model] Filtering out visual weights (using NPU backend)"
        )
        # NOTE(review): the prefix is matched before hf_to_vllm_mapper is
        # applied — confirm checkpoint names really start with "visual.".
        weights = _without_visual(weights)

    loader = AutoWeightsLoader(self)
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

Expand Down Expand Up @@ -1526,3 +1686,25 @@ def get_num_mm_connector_tokens(
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2

def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: tuple[torch.Tensor, ...] | None = None,
is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
"""Embed token ids and merge multimodal embeddings (V1 MM path)."""
inputs_embeds = self.language_model.model.embed_input_ids(input_ids)
if (
multimodal_embeddings is not None
and is_multimodal is not None
and len(multimodal_embeddings) > 0
):
from vllm.model_executor.models.utils import _merge_multimodal_embeddings

inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds,
multimodal_embeddings,
is_multimodal,
)
return inputs_embeds
58 changes: 58 additions & 0 deletions vllm/model_executor/models/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,3 +601,61 @@ def get_llm_pos_ids_for_vision(
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids


# ---------------------------------------------------------------------------
# NPU Vision Backend Support
# ---------------------------------------------------------------------------


def use_npu_vision_backend() -> bool:
    """Tell whether vision processing should go through the NPU backend.

    Returns:
        True when ``VLLM_VISION_NPU_BACKEND`` is set and, compared
        case-insensitively, equals ``"flexmlrt"``; False otherwise.
    """
    import vllm.envs as envs

    configured = envs.VLLM_VISION_NPU_BACKEND
    if not configured:
        return False
    return configured.lower() == "flexmlrt"


def get_npu_vision_backend():
    """Build the NPU vision backend selected through environment variables.

    Returns:
        ``None`` unless ``VLLM_VISION_NPU_BACKEND`` equals ``"flexmlrt"``
        (case-insensitive). Otherwise an ``AsyncFlexMLRTVisionBackend`` when
        ``VLLM_NPU_ASYNC_PIPELINE`` is enabled, else a
        ``FlexMLRTVisionBackend``. Both are constructed with the cache path
        from ``VLLM_VISION_NPU_CACHE`` and the device name from
        ``VLLM_VISION_NPU_DEVICE`` (``"stx"`` when unset).

    Raises:
        ValueError: If the FlexMLRT backend is selected but
            ``VLLM_VISION_NPU_CACHE`` is not set.
        ImportError: If ``vllm.vision_npu.flexmlrt_backend`` cannot be
            imported.
    """
    import vllm.envs as envs

    configured = envs.VLLM_VISION_NPU_BACKEND
    if not configured or configured.lower() != "flexmlrt":
        return None

    model_cache = envs.VLLM_VISION_NPU_CACHE
    if not model_cache:
        raise ValueError(
            "VLLM_VISION_NPU_CACHE must be set when using FlexMLRT backend"
        )
    device_name = envs.VLLM_VISION_NPU_DEVICE or "stx"

    # Only the class that is actually needed gets imported.
    if envs.VLLM_NPU_ASYNC_PIPELINE:
        from vllm.vision_npu.flexmlrt_backend import (
            AsyncFlexMLRTVisionBackend as backend_cls,
        )
    else:
        from vllm.vision_npu.flexmlrt_backend import (
            FlexMLRTVisionBackend as backend_cls,
        )

    return backend_cls(model_cache, device_name)
Loading
Loading