Skip to content

Commit 2550f94

Browse files
author
Vinayyyy7
committed
Sync with main: Add fix_vllm_pdl_blackwell
1 parent 9b553b5 commit 2550f94

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

unsloth/import_fixes.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,3 +549,118 @@ def fix_huggingface_hub():
549549
huggingface_hub.is_offline_mode = (
550550
lambda: huggingface_hub.constants.HF_HUB_OFFLINE
551551
)
552+
553+
554+
def fix_vllm_pdl_blackwell():
    """
    Fix vLLM PDL (Programmatic Dependent Launch) bug on Blackwell GPUs (SM100).

    The issue: vLLM's LoRA Triton kernels use tl.extra.cuda.gdc_wait() for PDL
    optimization on SM90+ GPUs. This fails on SM100 (B200/B100) during CUDA graph
    capture because Triton's pipeliner can't handle gdc_wait in complex kernels.

    See: https://github.com/vllm-project/vllm/issues/30872

    Always returns None; the observable effects are the TRITON_DISABLE_PDL
    environment variable and monkey-patched ``supports_pdl`` functions inside
    vLLM's LoRA Triton-op modules.
    """
    # Nothing to do when vLLM itself is not installed.
    if importlib.util.find_spec("vllm") is None:
        return

    # Locate a Blackwell part (compute capability major == 10). One hit is
    # enough: the workaround is applied process-wide, not per device.
    sm100_gpu_name = None
    try:
        import torch

        if not torch.cuda.is_available():
            return
        for device_index in range(torch.cuda.device_count()):
            major_cc, _minor_cc = torch.cuda.get_device_capability(device_index)
            if major_cc == 10:
                sm100_gpu_name = torch.cuda.get_device_name(device_index)
                break
        if sm100_gpu_name is None:
            return
    except Exception:
        # Best-effort probe: any torch/CUDA hiccup means we skip the fix.
        return

    def _spec_present(module_name):
        # find_spec may raise for dotted names whose parent package is missing.
        try:
            return importlib.util.find_spec(module_name) is not None
        except (ModuleNotFoundError, ValueError):
            return False

    # Very old vLLM releases expose none of the PDL-aware LoRA modules, so
    # there is nothing to patch (and no bug to work around).
    if not any(
        _spec_present(candidate)
        for candidate in (
            "vllm.lora.ops.triton_ops.utils",
            "vllm.lora.ops.triton_ops.lora_expand_op",
            "vllm.lora.ops.triton_ops.lora_shrink_op",
        )
    ):
        return

    # Releases newer than this are expected to carry the upstream fix, so the
    # workaround is skipped for them.
    VLLM_PDL_FIX_VERSION = "0.13.2"
    try:
        vllm_version = Version(importlib_version("vllm"))
        if vllm_version > Version(VLLM_PDL_FIX_VERSION):
            logger.info(
                f"Unsloth: SM100 ({sm100_gpu_name}) detected but vLLM {vllm_version} "
                f"should include PDL fix - skipping workaround"
            )
            return
    except Exception as e:
        # Can't tell the version - apply the workaround anyway (harmless).
        logger.debug(
            f"Unsloth: vLLM version check failed ({e}), applying PDL workaround."
        )

    # Apply the workaround: first disable PDL at the Triton level globally.
    os.environ["TRITON_DISABLE_PDL"] = "1"

    def _pdl_unsupported(*args, **kwargs):
        # Drop-in replacement for vLLM's supports_pdl(): always report False.
        return False

    patched = []
    _MISSING = object()

    # Patch the defining module first. supports_pdl is wrapped in @lru_cache,
    # so any stale cached result must be cleared before rebinding the name.
    try:
        defining_module = importlib.import_module("vllm.lora.ops.triton_ops.utils")
        current = getattr(defining_module, "supports_pdl", _MISSING)
        if current is not _MISSING:
            cache_clear = getattr(current, "cache_clear", None)
            if cache_clear is not None:
                cache_clear()
            defining_module.supports_pdl = _pdl_unsupported
            patched.append("utils")
    except (ImportError, ModuleNotFoundError, AttributeError):
        pass

    # Then patch every consumer that did `from ... import supports_pdl`: each
    # holds its own reference that the rebinding above does not reach if the
    # module was imported before this fix ran.
    for short_name, dotted_path in (
        ("lora_expand_op", "vllm.lora.ops.triton_ops.lora_expand_op"),
        ("lora_shrink_op", "vllm.lora.ops.triton_ops.lora_shrink_op"),
        ("fused_moe_lora_op", "vllm.lora.ops.triton_ops.fused_moe_lora_op"),
    ):
        try:
            consumer = importlib.import_module(dotted_path)
            if hasattr(consumer, "supports_pdl"):
                consumer.supports_pdl = _pdl_unsupported
                patched.append(short_name)
        except (ImportError, ModuleNotFoundError, AttributeError):
            pass

    if patched:
        logger.info(
            f"Unsloth: Applied PDL fix for SM100 ({sm100_gpu_name}) - "
            f"patched: {', '.join(patched)}"
        )
    else:
        # Nothing exposed supports_pdl (older vLLM) - the env var still helps.
        logger.info(f"Unsloth: Set TRITON_DISABLE_PDL=1 for SM100 ({sm100_gpu_name})")

0 commit comments

Comments
 (0)