Fix Marlin repack PTX incompatibility on H100/H200 (CUDA 12.8) #38669
DavidBellamy wants to merge 2 commits into vllm-project:main from
Conversation
Code Review
This pull request updates the build configuration to include native support for sm_90 in Marlin kernels and introduces a detailed error handler for PTX version mismatches during repack operations. Review feedback highlights that removing the +PTX suffix from sm_80 in CMakeLists.txt breaks compatibility for several architectures (like sm_86 and sm_89) and recommends ensuring both 8.0 and 9.0 retain the suffix. Additionally, the error message in vllm/_custom_ops.py should be generalized as it currently incorrectly specifies "MoE repack" for standard repack callers.
CMakeLists.txt
# Include 9.0 so that H100/H200 (sm_90) get native SASS instead of relying
# on PTX JIT, which fails when the wheel's CUDA toolkit is newer than the
# driver (e.g. wheel built with CTK 12.9 on a CUDA 12.8 driver).
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0;9.0+PTX" "${CUDA_ARCHS}")
Removing the +PTX suffix from 8.0 breaks compatibility for architectures like sm_86 (Ampere) and sm_89 (Ada) that are not explicitly listed in the Marlin arch strings. These devices rely on JIT-compiling from the sm_80 PTX because sm_80 SASS is not binary-compatible with them. To support sm_90 natively while maintaining compatibility for other sm_8x devices, both 8.0 and 9.0 should include the +PTX suffix.
Additionally, please consider applying this same change to MARLIN_ARCHS (line 354) and MARLIN_MOE_ARCHS (line 1063), as the main GEMM kernels will otherwise still trigger the same PTX JIT incompatibility on Hopper GPUs.
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
CMakeLists.txt
# Include 9.0 so that H100/H200 (sm_90) get native SASS instead of relying
# on PTX JIT, which fails when the wheel's CUDA toolkit is newer than the
# driver (e.g. wheel built with CTK 12.9 on a CUDA 12.8 driver).
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0;9.0+PTX" "${CUDA_ARCHS}")
raise RuntimeError(
    "Marlin MoE repack kernel failed with a CUDA error that usually "
    "indicates a PTX version mismatch: the pre-built vLLM wheel was "
    "compiled with a CUDA toolkit newer than what your GPU driver "
    f"supports (driver version: {cuda_ver}).\n\n"
    "To fix this, build vLLM from source with your system's CUDA "
    "toolkit:\n"
    "  pip install vllm --no-binary vllm\n\n"
    "Or install a matching CUDA toolkit / update your GPU driver.\n"
    "See https://github.com/vllm-project/vllm/issues/38619"
) from original_error
The error message is currently specific to "MoE repack", but this helper function is also used by standard Marlin repack operations (gptq_marlin_repack and awq_marlin_repack). A more generic message would be more accurate for all callers.
raise RuntimeError(
    "Marlin repack kernel failed with a CUDA error that usually "
    "indicates a PTX version mismatch: the pre-built vLLM wheel was "
    "compiled with a CUDA toolkit newer than what your GPU driver "
    f"supports (driver version: {cuda_ver}).\n\n"
    "To fix this, build vLLM from source with your system's CUDA "
    "toolkit:\n"
    "  pip install vllm --no-binary vllm\n\n"
    "Or install a matching CUDA toolkit / update your GPU driver.\n"
    "See https://github.com/vllm-project/vllm/issues/38619"
) from original_error
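Since the same message is shared by all four repack entry points, centralizing it in one helper keeps the callers uniform. The sketch below shows roughly how such a shared handler could look; the function name `handle_marlin_repack_error` and the `cuda_ver` parameter are illustrative assumptions, not vLLM's actual internals.

```python
# Illustrative sketch of a shared error handler for the four Marlin repack
# entry points (gptq_marlin_repack, awq_marlin_repack, and MoE variants).
# The helper name and signature are assumptions, not vLLM's actual API;
# `cuda_ver` is the CUDA version reported by the driver.

PTX_MISMATCH_MARKER = "unsupported toolchain"


def handle_marlin_repack_error(original_error: Exception, cuda_ver: str) -> None:
    """Re-raise a PTX toolchain mismatch as an actionable diagnostic."""
    if PTX_MISMATCH_MARKER not in str(original_error):
        # Unrelated CUDA failure: propagate unchanged.
        raise original_error
    raise RuntimeError(
        "Marlin repack kernel failed with a CUDA error that usually "
        "indicates a PTX version mismatch: the pre-built vLLM wheel was "
        "compiled with a CUDA toolkit newer than what your GPU driver "
        f"supports (driver version: {cuda_ver}).\n\n"
        "To fix this, build vLLM from source with your system's CUDA "
        "toolkit:\n"
        "  pip install vllm --no-binary vllm\n\n"
        "Or install a matching CUDA toolkit / update your GPU driver.\n"
        "See https://github.com/vllm-project/vllm/issues/38619"
    ) from original_error
```

Each call site would then wrap its `torch.ops` invocation in `try/except RuntimeError` and delegate to this helper, so the "MoE" vs non-MoE wording never diverges again.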
Add sm_90 to MARLIN_OTHER_ARCHS and MARLIN_MOE_OTHER_ARCHS so that Marlin repack kernels (gptq_marlin_repack, awq_marlin_repack) compile native SASS for H100/H200 instead of relying on PTX JIT. When a pre-built wheel is compiled with a newer CUDA toolkit than the driver supports (e.g. CTK 12.9 wheel on a 12.8 driver), PTX JIT fails with "the provided PTX was compiled with an unsupported toolchain." Also wrap all Marlin repack call sites with a try/except that catches the PTX toolchain error and raises a clear diagnostic message with the driver version and build-from-source instructions.

Fixes vllm-project#38619

Signed-off-by: David Bellamy <12414531+DavidBellamy@users.noreply.github.com>
e53adee to 3d72856
- Keep 8.0+PTX (not bare 8.0) so sm_86/sm_89 can still JIT from PTX
- Add 9.0+PTX to MARLIN_ARCHS and MARLIN_MOE_ARCHS (main GEMM kernels) to avoid the same PTX JIT issue on the inference path
- Generalize error message from "MoE repack" to "repack" since the helper is shared by all four repack functions

Signed-off-by: David Bellamy <12414531+DavidBellamy@users.noreply.github.com>
Summary
Fixes #38619. The Marlin MoE repack kernel (`gptq_marlin_moe_repack`) crashes with `CUDA error: the provided PTX was compiled with an unsupported toolchain` when serving quantized MoE models (e.g. Kimi K2.5) on H100/H200 with a CUDA 12.8 driver, because pre-built wheels compiled with a newer CUDA toolkit generate PTX that the 12.8 driver cannot JIT-compile.

Root cause:
`MARLIN_OTHER_ARCHS` and `MARLIN_MOE_OTHER_ARCHS` in `CMakeLists.txt` were set to `"7.5;8.0+PTX"`, meaning on sm_90 (H100/H200) the driver must JIT-compile sm_80 PTX at runtime. If the wheel was built with CTK 12.9+, the embedded PTX uses a newer ISA version than the 12.8 driver supports.

Changes:
- Add `9.0` to both `MARLIN_OTHER_ARCHS` and `MARLIN_MOE_OTHER_ARCHS` (`"7.5;8.0;9.0+PTX"`), so H100/H200 get native sm_90 SASS for Marlin repack kernels. The `+PTX` moves to 9.0 to preserve forward compatibility for future architectures.
- Wrap all Marlin repack call sites (`gptq_marlin_repack`, `awq_marlin_repack`, and their MoE variants) with a try/except that catches the "unsupported toolchain" CUDA error and raises a diagnostic message including the driver version and build-from-source instructions.

Testing
Validated on an M2 cluster node:
- `CUDA_HOME=/usr/local/cuda-12.8`, PyTorch 2.10.0+cu128
- `moonshotai/Kimi-K2.5` (1T params, compressed-tensors WNA16 INT4, 384 MoE experts)
- `--enforce-eager`, `--max-model-len 32768`
- Before the fix, serving crashed in `process_weights_after_loading` with the PTX toolchain error.