@@ -549,3 +549,118 @@ def fix_huggingface_hub():
549549 huggingface_hub .is_offline_mode = (
550550 lambda : huggingface_hub .constants .HF_HUB_OFFLINE
551551 )
552+
553+
def fix_vllm_pdl_blackwell():
    """
    Fix vLLM PDL (Programmatic Dependent Launch) bug on Blackwell GPUs (SM100).

    The issue: vLLM's LoRA Triton kernels use tl.extra.cuda.gdc_wait() for PDL
    optimization on SM90+ GPUs. This fails on SM100 (B200/B100) during CUDA graph
    capture because Triton's pipeliner can't handle gdc_wait in complex kernels.

    See: https://github.com/vllm-project/vllm/issues/30872

    The workaround is best-effort and does nothing unless all of these hold:
    vLLM is importable, a CUDA device with compute capability major == 10 is
    visible, at least one of the vLLM LoRA Triton modules exists, and the
    installed vLLM version is not newer than ``VLLM_PDL_FIX_VERSION``.

    Side effects when applied:
        * sets ``os.environ["TRITON_DISABLE_PDL"] = "1"``;
        * replaces ``supports_pdl`` with a function that always returns
          ``False`` on every vLLM LoRA Triton module that exposes it.

    Returns:
        None. The function never raises on missing or stubbed dependencies.
    """

    def _spec_exists(name):
        # find_spec raises ModuleNotFoundError when a parent package of a
        # dotted name is missing, and ValueError when the module is already in
        # sys.modules but has no usable __spec__ (e.g. a stub module). Both
        # cases mean "there is nothing here we can patch".
        try:
            return importlib.util.find_spec(name) is not None
        except (ModuleNotFoundError, ValueError):
            return False

    # Use the guarded helper for the top-level check as well, so a stubbed
    # ``vllm`` entry in sys.modules cannot crash this import-time fix.
    if not _spec_exists("vllm"):
        return

    # Check if any CUDA GPU is SM100 (Blackwell)
    try:
        import torch

        if not torch.cuda.is_available():
            return

        # Scan all GPUs for SM100 - fix applies globally via env var and monkey-patch
        sm100_gpu_name = None
        for i in range(torch.cuda.device_count()):
            major, _ = torch.cuda.get_device_capability(i)
            if major == 10:
                sm100_gpu_name = torch.cuda.get_device_name(i)
                break
        else:
            # Loop finished without finding an SM100 device.
            return
    except Exception:
        # Deliberate best-effort: any failure while probing CUDA means we
        # cannot tell whether the workaround is needed, so leave things alone.
        return

    # Check that vLLM ships at least one PDL-related module before patching.
    pdl_modules = (
        "vllm.lora.ops.triton_ops.utils",
        "vllm.lora.ops.triton_ops.lora_expand_op",
        "vllm.lora.ops.triton_ops.lora_shrink_op",
    )
    if not any(_spec_exists(module_name) for module_name in pdl_modules):
        # Old vLLM version without PDL support - nothing to patch
        return

    # Check if vLLM version includes the fix
    # NOTE(review): the strict ``>`` means VLLM_PDL_FIX_VERSION itself still
    # receives the workaround - confirm whether the constant names the first
    # fixed release or the last affected one.
    VLLM_PDL_FIX_VERSION = "0.13.2"
    try:
        vllm_version = Version(importlib_version("vllm"))
        if vllm_version > Version(VLLM_PDL_FIX_VERSION):
            logger.info(
                f"Unsloth: SM100 ({sm100_gpu_name}) detected but vLLM {vllm_version} "
                f"should include PDL fix - skipping workaround"
            )
            return
    except Exception as e:
        # An unknown version is treated as affected: fall through and patch.
        logger.debug(
            f"Unsloth: vLLM version check failed ({e}), applying PDL workaround."
        )

    # Apply the PDL fix
    os.environ["TRITON_DISABLE_PDL"] = "1"

    def fake_supports_pdl(*args, **kwargs):
        return False

    patched = []

    # First, patch the source module (utils.py) where supports_pdl is defined.
    # This is critical because supports_pdl uses @lru_cache - we must clear the
    # cache to prevent stale cached results from the original function.
    try:
        utils_module = importlib.import_module("vllm.lora.ops.triton_ops.utils")
        if hasattr(utils_module, "supports_pdl"):
            original_fn = utils_module.supports_pdl
            if hasattr(original_fn, "cache_clear"):
                original_fn.cache_clear()
            utils_module.supports_pdl = fake_supports_pdl
            patched.append("utils")
    except (ImportError, AttributeError) as e:
        # ModuleNotFoundError is a subclass of ImportError, so it is covered.
        logger.debug(f"Unsloth: could not patch supports_pdl in utils ({e})")

    # Also patch the consumer modules that import supports_pdl from utils.
    # This ensures the patched function is used even if the module was already
    # imported before this fix runs (they hold their own reference to it).
    consumer_modules = {
        "lora_expand_op": "vllm.lora.ops.triton_ops.lora_expand_op",
        "lora_shrink_op": "vllm.lora.ops.triton_ops.lora_shrink_op",
        "fused_moe_lora_op": "vllm.lora.ops.triton_ops.fused_moe_lora_op",
    }
    for name, path in consumer_modules.items():
        try:
            module = importlib.import_module(path)
            if hasattr(module, "supports_pdl"):
                module.supports_pdl = fake_supports_pdl
                patched.append(name)
        except (ImportError, AttributeError) as e:
            logger.debug(f"Unsloth: could not patch supports_pdl in {name} ({e})")

    if patched:
        logger.info(
            f"Unsloth: Applied PDL fix for SM100 ({sm100_gpu_name}) - "
            f"patched: {', '.join(patched)}"
        )
    else:
        # Just set the env var - vLLM might be an older version without supports_pdl
        logger.info(f"Unsloth: Set TRITON_DISABLE_PDL=1 for SM100 ({sm100_gpu_name})")
0 commit comments