Skip to content

Commit a64ff5a

Browse files
committed
Add explicit try import and fallback pattern
1 parent 0424b2d commit a64ff5a

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

python/sglang/srt/layers/attention/nsa/nsa_indexer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,14 +230,18 @@ def _with_real_sm_count(self):
230230
yield
231231

232232
def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor:
233-
if _is_cuda and hasattr(deep_gemm, "bf16_gemm_nt"):
233+
try:
234+
from deep_gemm import bf16_gemm_nt
235+
except ImportError:
236+
bf16_gemm_nt = None
237+
if bf16_gemm_nt is not None:
234238
weight = self.weights_proj.weight
235239
out = torch.empty(
236240
(x.shape[0], weight.shape[0]),
237241
dtype=torch.float32,
238242
device=x.device,
239243
)
240-
deep_gemm.bf16_gemm_nt(x, weight, out)
244+
bf16_gemm_nt(x, weight, out)
241245
return out
242246

243247
if _is_hip:

0 commit comments

Comments (0)