# Module-level logger; presumably init_logger wraps logging.getLogger with
# project defaults (e.g. once-only helpers like info_once used below) — confirm.
logger = init_logger (__name__ )
6969
7070
71- def _has_cutlass_dsl_cu13 () -> bool :
72- """Whether the CUDA-13 CuTe-DSL shared libs are installed.
73- """
74- try :
75- from importlib .metadata import distribution
76- except ImportError :
77- return False
78- try :
79- distribution ("nvidia-cutlass-dsl-libs-cu13" )
80- except Exception :
81- return False
82- return True
83-
84-
8571def _should_use_flashinfer_gdn_prefill (
8672 backend : str , head_k_dim : int | None
8773) -> bool :
@@ -106,7 +92,9 @@ def _should_use_flashinfer_gdn_prefill(
10692 return False # Neither Hopper nor Blackwell.
10793 if head_k_dim != 128 :
10894 return False
109- if not _has_cutlass_dsl_cu13 ():
95+ if current_platform .get_cuda_runtime_major () < 13 :
96+ return False
97+ if not current_platform .has_cutlass_dsl_cu13 ():
11098 return False
11199 return True
112100
@@ -121,7 +109,7 @@ def _log_gdn_backend_decision(
121109 device_cap = (
122110 str (current_platform .get_device_capability ()) if is_cuda else "n/a"
123111 )
124- cutlass_dsl_cu13_installed = _has_cutlass_dsl_cu13 ()
112+ cutlass_dsl_cu13_installed = current_platform . has_cutlass_dsl_cu13 ()
125113 logger .info_once (
126114 "GDN prefill backend inputs:\n "
127115 " requested=%s\n "
@@ -202,6 +190,12 @@ def __init__(self, head_k_dim: int | None = None) -> None:
202190 backend = str (backend_cfg ).strip ().lower ()
203191
204192 use_flashinfer = _should_use_flashinfer_gdn_prefill (backend , head_k_dim )
193+ if backend == "flashinfer" and not use_flashinfer :
194+ logger .warning_once (
195+ "GDN prefill backend 'flashinfer' is selected but "
196+ "cannot use this kernel on the current platform. "
197+ "Falling back to Triton/FLA."
198+ )
205199 _log_gdn_backend_decision (backend , head_k_dim , use_flashinfer )
206200
207201 self ._forward_method = (
0 commit comments