
Commit 6ce61da

Add deepgemm warm up and handle import by ENABLE_JIT_DEEPGEMM

1 parent a64ff5a · commit 6ce61da

File tree

3 files changed: +31 −6 lines

python/sglang/srt/layers/attention/nsa/nsa_indexer.py

Lines changed: 2 additions & 6 deletions
@@ -230,18 +230,14 @@ def _with_real_sm_count(self):
         yield
 
     def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor:
-        try:
-            from deep_gemm import bf16_gemm_nt
-        except ImportError:
-            bf16_gemm_nt = None
-        if bf16_gemm_nt is not None:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             weight = self.weights_proj.weight
             out = torch.empty(
                 (x.shape[0], weight.shape[0]),
                 dtype=torch.float32,
                 device=x.device,
             )
-            bf16_gemm_nt(x, weight, out)
+            deep_gemm_wrapper.gemm_nt_bf16bf16f32(x, weight, out)
             return out
 
     if _is_hip:
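For context, this hunk replaces a per-call-site try/except import with a single flag check. A minimal sketch of the wrapper-flag pattern the new code relies on (illustrative only; the real ENABLE_JIT_DEEPGEMM in deep_gemm_wrapper is derived from more than a bare import check):

# Illustrative sketch, not the actual sglang code: resolve the optional
# deep_gemm dependency once, in one module, and expose a boolean flag.
try:
    import deep_gemm  # DeepGEMM JIT kernels, if installed
    ENABLE_JIT_DEEPGEMM = True
except ImportError:
    deep_gemm = None
    ENABLE_JIT_DEEPGEMM = False  # callers fall back to a non-DeepGEMM path

Centralizing the import means call sites such as _weights_proj_bf16_in_fp32_out only branch on the flag, and every DeepGEMM call goes through the wrapper where warmup and compile bookkeeping can see it.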

python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py

Lines changed: 15 additions & 0 deletions
@@ -96,6 +96,7 @@ class DeepGemmKernelType(IntEnum):
     GROUPED_GEMM_NT_F8F8BF16_MASKED = auto()
     GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto()
     GEMM_NT_F8F8BF16 = auto()
+    GEMM_NT_BF16BF16F32 = auto()
 
 
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
@@ -216,6 +217,7 @@ def create(kernel_type: DeepGemmKernelType, **kwargs):
             DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
             DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
             DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
+            DeepGemmKernelType.GEMM_NT_BF16BF16F32: _BF16F32WarmupExecutor,
         }[kernel_type](**kwargs)
 
     @staticmethod
@@ -235,6 +237,9 @@ def get_memory_requirement(
             + num_groups * 4
             + num_groups * max_m * n * 2
         ) / _GB
+        elif kernel_type == DeepGemmKernelType.GEMM_NT_BF16BF16F32:
+            # bf16 lhs + bf16 rhs + fp32 out
+            return (max_m * k * 2 + n * k * 2 + max_m * n * 4) / _GB
         else:
             raise ValueError(f"Invalid kernel type: {kernel_type}")
 
@@ -317,6 +322,16 @@ def execute(self, m):
         )
 
 
+class _BF16F32WarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs = torch.empty((max_m, k), device="cuda", dtype=torch.bfloat16)
+        self.rhs = torch.empty((n, k), device="cuda", dtype=torch.bfloat16)
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.float32)
+
+    def execute(self, m):
+        deep_gemm.bf16_gemm_nt(self.lhs[:m], self.rhs, self.out[:m])
+
+
 @contextmanager
 def deep_gemm_execution_hook(
     m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
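To make the new memory estimate concrete: bf16 elements take 2 bytes and fp32 elements take 4, so the warmup buffers cost max_m*k*2 + n*k*2 + max_m*n*4 bytes. A small worked example (shapes are illustrative, not taken from the commit, and _GB is assumed to be 1024**3 as the name suggests):

# Worked example of the GEMM_NT_BF16BF16F32 memory estimate (illustrative shapes).
_GB = 1024 ** 3  # assumption: gibibytes, matching the formula's divisor

def bf16f32_warmup_gb(max_m: int, n: int, k: int) -> float:
    lhs = max_m * k * 2  # bf16 LHS buffer
    rhs = n * k * 2      # bf16 RHS buffer
    out = max_m * n * 4  # fp32 output buffer
    return (lhs + rhs + out) / _GB

# max_m=4096, n=128, k=7168:
# (58720256 + 1835008 + 2097152) / 2**30 ≈ 0.058 GB
print(bf16f32_warmup_gb(4096, 128, 7168))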

python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py

Lines changed: 14 additions & 0 deletions
@@ -102,6 +102,20 @@ def gemm_nt_f8f8bf16(
     )
 
 
+def gemm_nt_bf16bf16f32(
+    lhs: torch.Tensor,
+    rhs: torch.Tensor,
+    out: torch.Tensor,
+):
+    m, k = lhs.shape
+    n, _ = rhs.shape
+    num_groups = 1
+    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_BF16BF16F32
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        deep_gemm.bf16_gemm_nt(lhs, rhs, out)
+
+
 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     compile_utils.update_deep_gemm_config(gpu_id, server_args)
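A minimal usage sketch of the new entrypoint (the shapes and the import path are assumptions for illustration; it requires a CUDA device and a DeepGEMM build that exposes bf16_gemm_nt):

import torch

from sglang.srt.layers import deep_gemm_wrapper  # assumed import path

m, n, k = 256, 128, 7168  # illustrative shapes
x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)   # bf16 activations
w = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)   # bf16 weight, (n, k) row-major
out = torch.empty(m, n, device="cuda", dtype=torch.float32)  # fp32 result buffer

if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
    # NT layout: out = x @ w.T, with the result stored in fp32
    deep_gemm_wrapper.gemm_nt_bf16bf16f32(x, w, out)

Writing into a caller-provided fp32 out tensor matches how _weights_proj_bf16_in_fp32_out uses the entrypoint in the first file above.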
