We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 0424b2d · commit a64ff5a (Copy full SHA for a64ff5a)
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -230,14 +230,18 @@ def _with_real_sm_count(self):
230
yield
231
232
def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor:
233
- if _is_cuda and hasattr(deep_gemm, "bf16_gemm_nt"):
+ try:
234
+ from deep_gemm import bf16_gemm_nt
235
+ except ImportError:
236
+ bf16_gemm_nt = None
237
+ if bf16_gemm_nt is not None:
238
weight = self.weights_proj.weight
239
out = torch.empty(
240
(x.shape[0], weight.shape[0]),
241
dtype=torch.float32,
242
device=x.device,
243
)
- deep_gemm.bf16_gemm_nt(x, weight, out)
244
+ bf16_gemm_nt(x, weight, out)
245
return out
246
247
if _is_hip:
0 commit comments