Skip to content

Commit 57a19c7

Browse files
committed
[GDN] Address review feedback from Gemini
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
1 parent f7db516 commit 57a19c7

2 files changed

Lines changed: 29 additions & 16 deletions

File tree

vllm/model_executor/layers/mamba/gdn_linear_attn.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,6 @@
6868
logger = init_logger(__name__)
6969

7070

71-
def _has_cutlass_dsl_cu13() -> bool:
72-
"""Whether the CUDA-13 CuTe-DSL shared libs are installed.
73-
"""
74-
try:
75-
from importlib.metadata import distribution
76-
except ImportError:
77-
return False
78-
try:
79-
distribution("nvidia-cutlass-dsl-libs-cu13")
80-
except Exception:
81-
return False
82-
return True
83-
84-
8571
def _should_use_flashinfer_gdn_prefill(
8672
backend: str, head_k_dim: int | None
8773
) -> bool:
@@ -106,7 +92,9 @@ def _should_use_flashinfer_gdn_prefill(
10692
return False # Neither Hopper nor Blackwell.
10793
if head_k_dim != 128:
10894
return False
109-
if not _has_cutlass_dsl_cu13():
95+
if current_platform.get_cuda_runtime_major() < 13:
96+
return False
97+
if not current_platform.has_cutlass_dsl_cu13():
11098
return False
11199
return True
112100

@@ -121,7 +109,7 @@ def _log_gdn_backend_decision(
121109
device_cap = (
122110
str(current_platform.get_device_capability()) if is_cuda else "n/a"
123111
)
124-
cutlass_dsl_cu13_installed = _has_cutlass_dsl_cu13()
112+
cutlass_dsl_cu13_installed = current_platform.has_cutlass_dsl_cu13()
125113
logger.info_once(
126114
"GDN prefill backend inputs:\n"
127115
" requested=%s\n"
@@ -202,6 +190,12 @@ def __init__(self, head_k_dim: int | None = None) -> None:
202190
backend = str(backend_cfg).strip().lower()
203191

204192
use_flashinfer = _should_use_flashinfer_gdn_prefill(backend, head_k_dim)
193+
if backend == "flashinfer" and not use_flashinfer:
194+
logger.warning_once(
195+
"GDN prefill backend 'flashinfer' is selected but "
196+
"cannot use this kernel on the current platform. "
197+
"Falling back to Triton/FLA."
198+
)
205199
_log_gdn_backend_decision(backend, head_k_dim, use_flashinfer)
206200

207201
self._forward_method = (

vllm/platforms/interface.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,25 @@ def is_device_capability_family(
359359
return False
360360
return (current_capability.to_int() // 10) == (capability // 10)
361361

362+
@classmethod
363+
def get_cuda_runtime_major(cls) -> int:
364+
"""Major ``torch.version.cuda`` version, or ``0`` if undetermined."""
365+
major = (torch.version.cuda or "0").split(".", 1)[0]
366+
return int(major) if major.isdigit() else 0
367+
368+
@classmethod
369+
def has_cutlass_dsl_cu13(cls) -> bool:
370+
"""Whether ``nvidia-cutlass-dsl-libs-cu13`` is installed."""
371+
try:
372+
from importlib.metadata import distribution
373+
except ImportError:
374+
return False
375+
try:
376+
distribution("nvidia-cutlass-dsl-libs-cu13")
377+
except Exception:
378+
return False
379+
return True
380+
362381
@classmethod
363382
def get_device_name(cls, device_id: int = 0) -> str:
364383
"""Get the name of a device."""

0 commit comments

Comments (0)