Skip to content

Commit f603594

Browse files
author
Copilot
committed
[docker] support BF16 DeepEP
1 parent a2b16da commit f603594

File tree

2 files changed

+339
-1
lines changed

2 files changed

+339
-1
lines changed

docker/patch/latest/sglang.patch

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,282 @@ index aff05bf42..130359232 100644
855855
else:
856856
logits = torch.matmul(
857857
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
858+
diff --git a/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py b/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py
859+
new file mode 100644
860+
index 000000000..7500a3b27
861+
--- /dev/null
862+
+++ b/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py
863+
@@ -0,0 +1,150 @@
864+
+"""Fused Triton kernels for DeepEP BF16 low-latency MoE decode.
865+
+
866+
+Replaces the naive activation + masking pipeline (5+ CUDA kernels for silu+mul
867+
+and arange+comparison+masked_fill+copy) with a single Triton elementwise kernel,
868+
+while keeping cuBLAS batched GEMM for the matrix multiplies.
869+
+
870+
+Pipeline: bmm → fused_act_mul_masked (in-place) → bmm(out=hidden)
871+
+ (3 ops total: 2 cuBLAS + 1 Triton, vs original 7-8 separate CUDA kernels)
872+
+"""
873+
+
874+
+import torch
875+
+import triton
876+
+import triton.language as tl
877+
+
878+
+
879+
@triton.jit
def _silu_mul_masked_kernel(
    gate_up_ptr,
    masked_m_ptr,
    M,
    N,
    stride_ge,
    stride_gm,
    stride_gn,
    BLOCK: tl.constexpr,
):
    """Apply SiLU(gate) * up in-place with per-expert row masking.

    gate_up is laid out as [E, M, 2*N]: columns [0, N) hold the gate
    projection and columns [N, 2*N) hold the up projection. The product
    SiLU(gate) * up overwrites the gate half; rows at or beyond
    masked_m[expert] are written as zero.
    """
    expert = tl.program_id(1)
    block_id = tl.program_id(0)

    valid_rows = tl.load(masked_m_ptr + expert)

    idx = block_id * BLOCK + tl.arange(0, BLOCK)
    in_bounds = idx < M * N

    row = idx // N
    col = idx % N

    base = gate_up_ptr + expert * stride_ge
    gate_ptrs = base + row * stride_gm + col * stride_gn
    up_ptrs = base + row * stride_gm + (col + N) * stride_gn

    gate = tl.load(gate_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
    up = tl.load(up_ptrs, mask=in_bounds, other=0.0).to(tl.float32)

    # Activation in fp32 for accuracy; zero rows past the expert's token count.
    out = gate * tl.sigmoid(gate) * up
    out = tl.where(row < valid_rows, out, 0.0)

    tl.store(gate_ptrs, out.to(gate_up_ptr.dtype.element_ty), mask=in_bounds)
929+
+
930+
+
931+
@triton.jit
def _gelu_mul_masked_kernel(
    gate_up_ptr,
    masked_m_ptr,
    M,
    N,
    stride_ge,
    stride_gm,
    stride_gn,
    BLOCK: tl.constexpr,
):
    """Apply tanh-approximated GELU(gate) * up in-place with per-expert masking.

    Layout matches the SiLU variant: gate_up is [E, M, 2*N] with gate in
    the first N columns and up in the last N. The product overwrites the
    gate half; rows at or beyond masked_m[expert] are written as zero.
    """
    expert = tl.program_id(1)
    block_id = tl.program_id(0)

    valid_rows = tl.load(masked_m_ptr + expert)

    idx = block_id * BLOCK + tl.arange(0, BLOCK)
    in_bounds = idx < M * N

    row = idx // N
    col = idx % N

    base = gate_up_ptr + expert * stride_ge
    gate_ptrs = base + row * stride_gm + col * stride_gn
    up_ptrs = base + row * stride_gm + (col + N) * stride_gn

    g = tl.load(gate_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
    u = tl.load(up_ptrs, mask=in_bounds, other=0.0).to(tl.float32)

    # Tanh approximation of GELU in fp32; the constant is sqrt(2/pi).
    SQRT_2_OVER_PI = 0.7978845608028654
    inner = SQRT_2_OVER_PI * (g + 0.044715 * g * g * g)
    out = 0.5 * g * (1.0 + tl.math.tanh(inner)) * u
    out = tl.where(row < valid_rows, out, 0.0)

    tl.store(gate_ptrs, out.to(gate_up_ptr.dtype.element_ty), mask=in_bounds)
977+
+
978+
+
979+
def fused_act_mul_masked_inplace(
    gate_up: torch.Tensor,
    intermediate_size: int,
    masked_m: torch.Tensor,
    use_gelu: bool = False,
) -> None:
    """Fused activation + multiply + masking, written in-place to gate_up[:, :, :I].

    After this call, gate_up[:, :, :intermediate_size] contains the masked
    activated intermediate, suitable for the down-projection GEMM.

    Args:
        gate_up: [E, M, 2*I] output of bmm(tokens, w13.T); modified in-place.
        intermediate_size: I, the per-partition intermediate width.
        masked_m: [E] per-expert valid token count; rows >= masked_m[e]
            are zeroed by the kernel.
        use_gelu: apply tanh-approximated GELU instead of SiLU to the gate half.

    Raises:
        ValueError: if gate_up's last dimension is not 2 * intermediate_size,
            or masked_m does not hold exactly one entry per expert.
    """
    E, M, last_dim = gate_up.shape
    N = intermediate_size

    # Validate shapes up front: the kernel writes in-place through raw
    # strides, so a mismatch would silently corrupt memory.
    if last_dim != 2 * N:
        raise ValueError(
            f"gate_up last dim {last_dim} != 2 * intermediate_size ({2 * N})"
        )
    if masked_m.numel() != E:
        raise ValueError(
            f"masked_m has {masked_m.numel()} entries, expected {E} (one per expert)"
        )

    # Degenerate sizes: nothing to activate, skip the (empty) kernel launch.
    if E == 0 or M == 0 or N == 0:
        return

    total = M * N
    BLOCK = 1024
    grid = (triton.cdiv(total, BLOCK), E)

    kernel = _gelu_mul_masked_kernel if use_gelu else _silu_mul_masked_kernel
    kernel[grid](
        gate_up,
        masked_m,
        M,
        N,
        gate_up.stride(0),
        gate_up.stride(1),
        gate_up.stride(2),
        BLOCK=BLOCK,
    )
1014+
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
1015+
index ebcc696ec..3b527021a 100644
1016+
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
1017+
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
1018+
@@ -132,11 +132,12 @@ class DeepEPMoE(FusedMoE):
1019+
and not _is_npu
1020+
and not (
1021+
get_moe_runner_backend().is_flashinfer_cutedsl()
1022+
+ and self.quant_config is not None
1023+
and self.quant_config.get_name() == "modelopt_fp4"
1024+
)
1025+
+ and (self.use_fp8_w8a8 or self.use_w4afp8)
1026+
):
1027+
- # NPU supports low_latency deepep without deepgemm
1028+
- # FP4 quantization with flashinfer_cutedsl also supports low_latency deepep without deepgemm
1029+
+ # BF16 models don't need deep_gemm; they use per-expert torch.mm
1030+
assert (
1031+
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
1032+
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
1033+
@@ -154,6 +155,10 @@ class DeepEPMoE(FusedMoE):
1034+
# the last one is invalid rank_id
1035+
self.expert_mask[:-1] = 1
1036+
1037+
+ # Set bf16_weights flag on dispatcher so dispatch skips FP8 quantization
1038+
+ if not self.use_fp8_w8a8 and not self.use_w4afp8:
1039+
+ self.dispatcher.set_quant_config({"bf16_weights": True})
1040+
+
1041+
def forward(
1042+
self,
1043+
hidden_states: torch.Tensor,
1044+
@@ -228,6 +233,8 @@ class DeepEPMoE(FusedMoE):
1045+
elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
1046+
if self.use_w4afp8:
1047+
output = self.forward_cutlass_w4afp8(dispatch_output)
1048+
+ elif not self.use_fp8_w8a8:
1049+
+ output = self.forward_bf16_normal(dispatch_output)
1050+
else:
1051+
assert False, "forward_deepgemm_contiguous is deprecated"
1052+
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
1053+
@@ -238,6 +245,8 @@ class DeepEPMoE(FusedMoE):
1054+
output = self.forward_flashinfer_cutedsl(dispatch_output)
1055+
elif self.use_w4afp8:
1056+
output = self.forward_cutlass_w4afp8_masked(dispatch_output)
1057+
+ elif not self.use_fp8_w8a8:
1058+
+ output = self.forward_bf16_ll(dispatch_output)
1059+
else:
1060+
assert False, "forward_deepgemm_masked is deprecated"
1061+
1062+
@@ -341,6 +350,71 @@ class DeepEPMoE(FusedMoE):
1063+
dispatch_output=dispatch_output,
1064+
)
1065+
1066+
+ def forward_bf16_normal(
1067+
+ self,
1068+
+ dispatch_output: DeepEPNormalDispatchOutput,
1069+
+ ) -> torch.Tensor:
1070+
+ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
1071+
+
1072+
+ hidden_states = dispatch_output.hidden_states
1073+
+ topk_ids = dispatch_output.topk_ids
1074+
+ topk_weights = dispatch_output.topk_weights
1075+
+
1076+
+ if hidden_states.shape[0] == 0:
1077+
+ return hidden_states
1078+
+
1079+
+ # topk_ids uses local expert IDs (0..num_local_experts-1), -1 for remote.
1080+
+ # fused_experts handles -1 via moe_align_block_size filtering.
1081+
+ return fused_experts(
1082+
+ hidden_states=hidden_states,
1083+
+ w1=self.w13_weight,
1084+
+ w2=self.w2_weight,
1085+
+ topk_output=(topk_weights, topk_ids, None),
1086+
+ moe_runner_config=self.moe_runner_config,
1087+
+ )
1088+
+
1089+
+ def forward_bf16_ll(
1090+
+ self,
1091+
+ dispatch_output: DeepEPLLDispatchOutput,
1092+
+ ) -> torch.Tensor:
1093+
+ from sglang.srt.layers.moe.ep_moe.deepep_bf16_kernels import (
1094+
+ fused_act_mul_masked_inplace,
1095+
+ )
1096+
+
1097+
+ hidden_states = dispatch_output.hidden_states
1098+
+ masked_m = dispatch_output.masked_m
1099+
+ expected_m = dispatch_output.expected_m
1100+
+
1101+
+ _, max_tokens, _ = hidden_states.shape
1102+
+ if masked_m.numel() == 0 or max_tokens == 0:
1103+
+ return hidden_states
1104+
+
1105+
+ expected_m = min(expected_m, max_tokens)
1106+
+ if expected_m <= 0:
1107+
+ return hidden_states
1108+
+
1109+
+ tokens = hidden_states[:, :expected_m, :]
1110+
+
1111+
+ # 1. Gate+Up GEMM (cuBLAS batched GEMM)
1112+
+ gate_up = torch.bmm(tokens, self.w13_weight.transpose(1, 2))
1113+
+
1114+
+ # 2. Fused SiLU(gate)*up + masking in-place (1 Triton kernel replaces 6 ops)
1115+
+ fused_act_mul_masked_inplace(
1116+
+ gate_up,
1117+
+ self.intermediate_size_per_partition,
1118+
+ masked_m,
1119+
+ use_gelu=(self.moe_runner_config.activation == "gelu"),
1120+
+ )
1121+
+
1122+
+ # 3. Down GEMM into hidden_states (cuBLAS, non-contiguous input is OK)
1123+
+ torch.bmm(
1124+
+ gate_up[:, :, : self.intermediate_size_per_partition],
1125+
+ self.w2_weight.transpose(1, 2),
1126+
+ out=hidden_states[:, :expected_m, :],
1127+
+ )
1128+
+
1129+
+ return hidden_states
1130+
+
1131+
def forward_npu(
1132+
self,
1133+
dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
8581134
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
8591135
index ebdbb42c6..714ffbe0e 100644
8601136
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -972,6 +1248,68 @@ index 00bd68755..5a3ca8a67 100644
9721248
self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)
9731249

9741250
def get_routed_experts(
1251+
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
1252+
index 8539639d5..b1f614140 100644
1253+
--- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
1254+
+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
1255+
@@ -388,6 +388,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
1256+
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
1257+
and not get_moe_runner_backend().is_cutlass()
1258+
and not envs.SGLANG_DEEPEP_BF16_DISPATCH.get()
1259+
+ and not self.quant_config.get("bf16_weights", False)
1260+
):
1261+
# TODO hard code 128 block quant,use fp8 communication
1262+
hidden_states = sglang_per_token_group_quant_fp8(
1263+
@@ -466,7 +467,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
1264+
previous_event=previous_event,
1265+
async_finish=self.async_finish,
1266+
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
1267+
- expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
1268+
+ expert_alignment=128
1269+
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
1270+
+ and not self.quant_config.get("bf16_weights", False)
1271+
+ else 1,
1272+
config=DeepEPConfig.get_instance().normal_dispatch_config,
1273+
)
1274+
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
1275+
@@ -491,7 +495,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
1276+
topk_weights: torch.Tensor,
1277+
):
1278+
1279+
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
1280+
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu or self.quant_config.get("bf16_weights", False):
1281+
output = hidden_states
1282+
else:
1283+
raise NotImplementedError() # triton runner was supported but it's temporarily disabled
1284+
@@ -551,10 +555,12 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
1285+
buffer = self._get_buffer()
1286+
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
1287+
topk_ids = topk_ids.to(torch.int64)
1288+
- expected_m = (
1289+
- hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
1290+
- + self.num_experts
1291+
- ) // self.num_experts
1292+
+ # Use a correctness-preserving upper bound for per-expert token count.
1293+
+ # In the worst case, every rank routes all local tokens to the same expert.
1294+
+ expected_m = min(
1295+
+ hidden_states.shape[0] * buffer.group_size,
1296+
+ self.num_max_dispatch_tokens_per_rank * buffer.group_size,
1297+
+ )
1298+
hidden_states, masked_m, event, hook = self._dispatch_core(
1299+
hidden_states,
1300+
topk_ids,
1301+
@@ -609,7 +615,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
1302+
input_global_scale = self.quant_config.get("input_global_scale", None)
1303+
if input_global_scale is not None:
1304+
use_nvfp4 = True
1305+
- elif not envs.SGLANG_DEEPEP_BF16_DISPATCH.get():
1306+
+ elif (
1307+
+ not envs.SGLANG_DEEPEP_BF16_DISPATCH.get()
1308+
+ and not self.quant_config.get("bf16_weights", False)
1309+
+ ):
1310+
use_fp8 = True
1311+
1312+
buffer = self._get_buffer()
9751313
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
9761314
index 4cbfed6f9..88b452744 100644
9771315
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

docker/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
nightly-dev-20260227a
1+
nightly-dev-20260302a

0 commit comments

Comments
 (0)