Skip to content

Commit 0424b2d

Browse files
committed
Implement DeepGEMM bf16-input/fp32-output GEMM for the indexer's weights_proj
1 parent db34c1c commit 0424b2d

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

python/sglang/srt/layers/attention/nsa/nsa_indexer.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -229,21 +229,31 @@ def _with_real_sm_count(self):
229229
else:
230230
yield
231231

232-
@torch.compile(dynamic=True) if not _is_hip else lambda f: f
233-
def _project_and_scale_head_gates(self, x: torch.Tensor):
232+
def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor:
233+
if _is_cuda and hasattr(deep_gemm, "bf16_gemm_nt"):
234+
weight = self.weights_proj.weight
235+
out = torch.empty(
236+
(x.shape[0], weight.shape[0]),
237+
dtype=torch.float32,
238+
device=x.device,
239+
)
240+
deep_gemm.bf16_gemm_nt(x, weight, out)
241+
return out
242+
234243
if _is_hip:
235244
x = x.to(self.weights_proj.weight.dtype)
236245
weights, _ = self.weights_proj(x)
237-
weights = weights.float()
246+
return weights.float()
247+
248+
@torch.compile(dynamic=True) if not _is_hip else lambda f: f
249+
def _project_and_scale_head_gates(self, x: torch.Tensor):
250+
weights = self._weights_proj_bf16_in_fp32_out(x)
238251
weights = weights * self.n_heads**-0.5
239252
return weights
240253

241254
@torch.compile(dynamic=True) if not _is_hip else lambda f: f
242255
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
243-
if _is_hip:
244-
x = x.to(self.weights_proj.weight.dtype)
245-
weights, _ = self.weights_proj(x)
246-
weights = weights.float()
256+
weights = self._weights_proj_bf16_in_fp32_out(x)
247257
weights = weights * self.n_heads**-0.5
248258
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
249259
return weights

0 commit comments

Comments
 (0)