Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# are not supported by Machete yet.

# marlin arches for fp16 output
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# Include 9.0 so H100/H200 get native SASS (see MARLIN_OTHER_ARCHS comment)
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
# marlin has limited support for turing
cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
Expand All @@ -362,7 +363,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}")
# marlin arches for other files
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
# Include 9.0 so that H100/H200 (sm_90) get native SASS instead of relying
# on PTX JIT, which fails when the wheel's CUDA toolkit is newer than the
# driver (e.g. wheel built with CTK 12.9 on a CUDA 12.8 driver).
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")

if (MARLIN_OTHER_ARCHS)

Expand Down Expand Up @@ -1057,7 +1061,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# moe marlin arches
# note that we always set `use_atomic_add=False` for moe marlin now,
# so we don't need 9.0 for bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# Include 9.0 so H100/H200 get native SASS (see MARLIN_MOE_OTHER_ARCHS comment)
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
# moe marlin has limited support for turing
cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
# moe marlin arches for fp8 input
Expand All @@ -1066,7 +1071,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}")
# moe marlin arches for other files
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
# Include 9.0 so that H100/H200 (sm_90) get native SASS instead of relying
# on PTX JIT, which fails when the wheel's CUDA toolkit is newer than the
# driver (e.g. wheel built with CTK 12.9 on a CUDA 12.8 driver).
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_OTHER_ARCHS)

#
Expand Down
72 changes: 60 additions & 12 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,9 +1175,14 @@ def gptq_marlin_repack(
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
return torch.ops._C.gptq_marlin_repack(
b_q_weight, perm, size_k, size_n, num_bits, is_a_8bit
)
try:
return torch.ops._C.gptq_marlin_repack(
b_q_weight, perm, size_k, size_n, num_bits, is_a_8bit
)
except RuntimeError as err:
if "unsupported toolchain" in str(err):
_raise_ptx_error(err)
raise


if hasattr(torch.ops._C, "gptq_marlin_repack"):
Expand Down Expand Up @@ -1208,9 +1213,14 @@ def awq_marlin_repack(
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
return torch.ops._C.awq_marlin_repack(
b_q_weight, size_k, size_n, num_bits, is_a_8bit
)
try:
return torch.ops._C.awq_marlin_repack(
b_q_weight, size_k, size_n, num_bits, is_a_8bit
)
except RuntimeError as err:
if "unsupported toolchain" in str(err):
_raise_ptx_error(err)
raise


if hasattr(torch.ops._C, "awq_marlin_repack"):
Expand All @@ -1232,6 +1242,34 @@ def _awq_marlin_repack_fake(
)


def _raise_ptx_error(original_error: Exception) -> None:
"""Raise an informative error when a Marlin kernel fails due to PTX
incompatibility (e.g. pre-built wheel compiled with CUDA 12.9+ running
on a system with a CUDA 12.8 driver)."""
import subprocess
cuda_ver = "unknown"
try:
result = subprocess.run(["nvidia-smi",
"--query-gpu=driver_version",
"--format=csv,noheader"],
capture_output=True, text=True, timeout=5)
if result.returncode == 0:
cuda_ver = result.stdout.strip().split("\n")[0]
except Exception:
pass
raise RuntimeError(
"Marlin repack kernel failed with a CUDA error that usually "
"indicates a PTX version mismatch: the pre-built vLLM wheel was "
"compiled with a CUDA toolkit newer than what your GPU driver "
f"supports (driver version: {cuda_ver}).\n\n"
"To fix this, build vLLM from source with your system's CUDA "
"toolkit:\n"
" pip install vllm --no-binary vllm\n\n"
"Or install a matching CUDA toolkit / update your GPU driver.\n"
"See https://github.com/vllm-project/vllm/issues/38619"
) from original_error
Comment on lines +1260 to +1270
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The error message is currently specific to "MoE repack", but this helper function is also used by standard Marlin repack operations (gptq_marlin_repack and awq_marlin_repack). A more generic message would be more accurate for all callers.

Suggested change
raise RuntimeError(
"Marlin MoE repack kernel failed with a CUDA error that usually "
"indicates a PTX version mismatch: the pre-built vLLM wheel was "
"compiled with a CUDA toolkit newer than what your GPU driver "
f"supports (driver version: {cuda_ver}).\n\n"
"To fix this, build vLLM from source with your system's CUDA "
"toolkit:\n"
" pip install vllm --no-binary vllm\n\n"
"Or install a matching CUDA toolkit / update your GPU driver.\n"
"See https://github.com/vllm-project/vllm/issues/38619"
) from original_error
raise RuntimeError(
"Marlin repack kernel failed with a CUDA error that usually "
"indicates a PTX version mismatch: the pre-built vLLM wheel was "
"compiled with a CUDA toolkit newer than what your GPU driver "
f"supports (driver version: {cuda_ver}).\n\n"
"To fix this, build vLLM from source with your system's CUDA "
"toolkit:\n"
" pip install vllm --no-binary vllm\n\n"
"Or install a matching CUDA toolkit / update your GPU driver.\n"
"See https://github.com/vllm-project/vllm/issues/38619"
) from original_error



def gptq_marlin_moe_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,
Expand All @@ -1248,9 +1286,14 @@ def gptq_marlin_moe_repack(
dtype=b_q_weight.dtype,
)
for e in range(num_experts):
output[e] = torch.ops._C.gptq_marlin_repack(
b_q_weight[e], perm[e], size_k, size_n, num_bits, is_a_8bit
)
try:
output[e] = torch.ops._C.gptq_marlin_repack(
b_q_weight[e], perm[e], size_k, size_n, num_bits, is_a_8bit
)
except RuntimeError as err:
if "unsupported toolchain" in str(err):
_raise_ptx_error(err)
raise
return output


Expand All @@ -1270,9 +1313,14 @@ def awq_marlin_moe_repack(
dtype=b_q_weight.dtype,
)
for e in range(num_experts):
output[e] = torch.ops._C.awq_marlin_repack(
b_q_weight[e], size_k, size_n, num_bits, is_a_8bit
)
try:
output[e] = torch.ops._C.awq_marlin_repack(
b_q_weight[e], size_k, size_n, num_bits, is_a_8bit
)
except RuntimeError as err:
if "unsupported toolchain" in str(err):
_raise_ptx_error(err)
raise
return output


Expand Down
Loading