@@ -855,6 +855,282 @@ index aff05bf42..130359232 100644
855855 else:
856856 logits = torch.matmul(
857857 hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
858+ diff --git a/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py b/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py
859+ new file mode 100644
860+ index 000000000..7500a3b27
861+ --- /dev/null
862+ +++ b/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py
863+ @@ -0,0 +1,150 @@
864+ + """Fused Triton kernels for DeepEP BF16 low-latency MoE decode.
865+ +
866+ + Replaces the naive activation + masking pipeline (5+ CUDA kernels for silu+mul
867+ + and arange+comparison+masked_fill+copy) with a single Triton elementwise kernel,
868+ + while keeping cuBLAS batched GEMM for the matrix multiplies.
869+ +
870+ + Pipeline: bmm → fused_act_mul_masked_inplace → bmm(out=hidden_states)
871+ + (3 ops total: 2 cuBLAS + 1 Triton, vs original 7-8 separate CUDA kernels)
872+ + """
873+ +
874+ + import torch
875+ + import triton
876+ + import triton.language as tl
877+ +
878+ +
879+ + @triton.jit
880+ + def _silu_mul_masked_kernel(
881+ + gate_up_ptr,
882+ + masked_m_ptr,
883+ + M,
884+ + N,
885+ + stride_ge,
886+ + stride_gm,
887+ + stride_gn,
888+ + BLOCK: tl.constexpr,
889+ + ):
890+ + """Fused SiLU(gate) * up with per-expert masking, written in-place.
891+ +
892+ + gate_up: [E, M, 2*N] — first N cols are gate, last N cols are up.
893+ + Writes SiLU(gate)*up to gate_up[:,:,:N] in-place.
894+ + Rows m >= masked_m[e] are zeroed in the first N columns (the up half is left untouched).
895+ + """
896+ + expert_id = tl.program_id(1)
897+ + pid = tl.program_id(0)
898+ +
899+ + expert_valid_m = tl.load(masked_m_ptr + expert_id)
900+ +
901+ + offs = pid * BLOCK + tl.arange(0, BLOCK)
902+ + total = M * N
903+ + mask = offs < total
904+ +
905+ + m = offs // N
906+ + n = offs % N
907+ +
908+ + gate_base = gate_up_ptr + expert_id * stride_ge
909+ +
910+ + gate_val = tl.load(
911+ + gate_base + m * stride_gm + n * stride_gn, mask=mask, other=0.0
912+ + )
913+ + up_val = tl.load(
914+ + gate_base + m * stride_gm + (n + N) * stride_gn, mask=mask, other=0.0
915+ + )
916+ +
917+ + gate_f32 = gate_val.to(tl.float32)
918+ + result = (gate_f32 * tl.sigmoid(gate_f32)) * up_val.to(tl.float32)
919+ +
920+ + # Zero invalid rows
921+ + valid = m < expert_valid_m
922+ + result = tl.where(valid, result, 0.0)
923+ +
924+ + tl.store(
925+ + gate_base + m * stride_gm + n * stride_gn,
926+ + result.to(gate_up_ptr.dtype.element_ty),
927+ + mask=mask,
928+ + )
929+ +
930+ +
931+ + @triton.jit
932+ + def _gelu_mul_masked_kernel(
933+ + gate_up_ptr,
934+ + masked_m_ptr,
935+ + M,
936+ + N,
937+ + stride_ge,
938+ + stride_gm,
939+ + stride_gn,
940+ + BLOCK: tl.constexpr,
941+ + ):
942+ + """Fused tanh-approximated GELU(gate) * up with per-expert masking, written in-place."""
943+ + expert_id = tl.program_id(1)
944+ + pid = tl.program_id(0)
945+ +
946+ + expert_valid_m = tl.load(masked_m_ptr + expert_id)
947+ +
948+ + offs = pid * BLOCK + tl.arange(0, BLOCK)
949+ + total = M * N
950+ + mask = offs < total
951+ +
952+ + m = offs // N
953+ + n = offs % N
954+ +
955+ + gate_base = gate_up_ptr + expert_id * stride_ge
956+ +
957+ + gate_val = tl.load(
958+ + gate_base + m * stride_gm + n * stride_gn, mask=mask, other=0.0
959+ + )
960+ + up_val = tl.load(
961+ + gate_base + m * stride_gm + (n + N) * stride_gn, mask=mask, other=0.0
962+ + )
963+ +
964+ + g = gate_val.to(tl.float32)
965+ + kAlpha = 0.7978845608028654
966+ + gate_act = 0.5 * g * (1.0 + tl.math.tanh(kAlpha * (g + 0.044715 * g * g * g)))
967+ + result = gate_act * up_val.to(tl.float32)
968+ +
969+ + valid = m < expert_valid_m
970+ + result = tl.where(valid, result, 0.0)
971+ +
972+ + tl.store(
973+ + gate_base + m * stride_gm + n * stride_gn,
974+ + result.to(gate_up_ptr.dtype.element_ty),
975+ + mask=mask,
976+ + )
977+ +
978+ +
979+ + def fused_act_mul_masked_inplace(
980+ + gate_up: torch.Tensor,
981+ + intermediate_size: int,
982+ + masked_m: torch.Tensor,
983+ + use_gelu: bool = False,
984+ + ) -> None:
985+ + """Fused activation + multiply + masking, written in-place to gate_up[:,:,:I].
986+ +
987+ + After this call, gate_up[:, :, :intermediate_size] contains the masked
988+ + activated intermediate, suitable for the down projection GEMM.
989+ +
990+ + Args:
991+ + gate_up: [E, M, 2*I] output of bmm(tokens, w13.T), modified in-place
992+ + intermediate_size: I
993+ + masked_m: [E] per-expert valid token count
994+ + use_gelu: use tanh-approximated GELU instead of SiLU
995+ + """
996+ + E, M, _ = gate_up.shape
997+ + N = intermediate_size
998+ +
999+ + total = M * N
1000+ + BLOCK = 1024
1001+ + grid = (triton.cdiv(total, BLOCK), E)
1002+ +
1003+ + kernel = _gelu_mul_masked_kernel if use_gelu else _silu_mul_masked_kernel
1004+ + kernel[grid](
1005+ + gate_up,
1006+ + masked_m,
1007+ + M,
1008+ + N,
1009+ + gate_up.stride(0),
1010+ + gate_up.stride(1),
1011+ + gate_up.stride(2),
1012+ + BLOCK=BLOCK,
1013+ + )
1014+ diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
1015+ index ebcc696ec..3b527021a 100644
1016+ --- a/python/sglang/srt/layers/moe/ep_moe/layer.py
1017+ +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
1018+ @@ -132,11 +132,12 @@ class DeepEPMoE(FusedMoE):
1019+ and not _is_npu
1020+ and not (
1021+ get_moe_runner_backend().is_flashinfer_cutedsl()
1022+ + and self.quant_config is not None
1023+ and self.quant_config.get_name() == "modelopt_fp4"
1024+ )
1025+ + and (self.use_fp8_w8a8 or self.use_w4afp8)
1026+ ):
1027+ - # NPU supports low_latency deepep without deepgemm
1028+ - # FP4 quantization with flashinfer_cutedsl also supports low_latency deepep without deepgemm
1029+ + # Only FP8 / W4AFP8 need deep_gemm here; BF16 uses fused_experts (normal) or torch.bmm (low-latency)
1030+ assert (
1031+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
1032+ ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
1033+ @@ -154,6 +155,10 @@ class DeepEPMoE(FusedMoE):
1034+ # the last one is invalid rank_id
1035+ self.expert_mask[:-1] = 1
1036+
1037+ + # Set bf16_weights flag on dispatcher so dispatch skips FP8 quantization
1038+ + if not self.use_fp8_w8a8 and not self.use_w4afp8:
1039+ + self.dispatcher.set_quant_config({"bf16_weights": True})
1040+ +
1041+ def forward(
1042+ self,
1043+ hidden_states: torch.Tensor,
1044+ @@ -228,6 +233,8 @@ class DeepEPMoE(FusedMoE):
1045+ elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
1046+ if self.use_w4afp8:
1047+ output = self.forward_cutlass_w4afp8(dispatch_output)
1048+ + elif not self.use_fp8_w8a8:
1049+ + output = self.forward_bf16_normal(dispatch_output)
1050+ else:
1051+ assert False, "forward_deepgemm_contiguous is deprecated"
1052+ elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
1053+ @@ -238,6 +245,8 @@ class DeepEPMoE(FusedMoE):
1054+ output = self.forward_flashinfer_cutedsl(dispatch_output)
1055+ elif self.use_w4afp8:
1056+ output = self.forward_cutlass_w4afp8_masked(dispatch_output)
1057+ + elif not self.use_fp8_w8a8:
1058+ + output = self.forward_bf16_ll(dispatch_output)
1059+ else:
1060+ assert False, "forward_deepgemm_masked is deprecated"
1061+
1062+ @@ -341,6 +350,71 @@ class DeepEPMoE(FusedMoE):
1063+ dispatch_output=dispatch_output,
1064+ )
1065+
1066+ + def forward_bf16_normal(
1067+ + self,
1068+ + dispatch_output: DeepEPNormalDispatchOutput,
1069+ + ) -> torch.Tensor:
1070+ + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
1071+ +
1072+ + hidden_states = dispatch_output.hidden_states
1073+ + topk_ids = dispatch_output.topk_ids
1074+ + topk_weights = dispatch_output.topk_weights
1075+ +
1076+ + if hidden_states.shape[0] == 0:
1077+ + return hidden_states
1078+ +
1079+ + # topk_ids uses local expert IDs (0..num_local_experts-1), -1 for remote.
1080+ + # fused_experts handles -1 via moe_align_block_size filtering.
1081+ + return fused_experts(
1082+ + hidden_states=hidden_states,
1083+ + w1=self.w13_weight,
1084+ + w2=self.w2_weight,
1085+ + topk_output=(topk_weights, topk_ids, None),
1086+ + moe_runner_config=self.moe_runner_config,
1087+ + )
1088+ +
1089+ + def forward_bf16_ll(
1090+ + self,
1091+ + dispatch_output: DeepEPLLDispatchOutput,
1092+ + ) -> torch.Tensor:
1093+ + from sglang.srt.layers.moe.ep_moe.deepep_bf16_kernels import (
1094+ + fused_act_mul_masked_inplace,
1095+ + )
1096+ +
1097+ + hidden_states = dispatch_output.hidden_states
1098+ + masked_m = dispatch_output.masked_m
1099+ + expected_m = dispatch_output.expected_m
1100+ +
1101+ + _, max_tokens, _ = hidden_states.shape
1102+ + if masked_m.numel() == 0 or max_tokens == 0:
1103+ + return hidden_states
1104+ +
1105+ + expected_m = min(expected_m, max_tokens)
1106+ + if expected_m <= 0:
1107+ + return hidden_states
1108+ +
1109+ + tokens = hidden_states[:, :expected_m, :]
1110+ +
1111+ + # 1. Gate+Up GEMM (cuBLAS batched GEMM)
1112+ + gate_up = torch.bmm(tokens, self.w13_weight.transpose(1, 2))
1113+ +
1114+ + # 2. Fused act(gate)*up + masking in-place; act is SiLU or GELU (1 Triton kernel replaces 6 ops)
1115+ + fused_act_mul_masked_inplace(
1116+ + gate_up,
1117+ + self.intermediate_size_per_partition,
1118+ + masked_m,
1119+ + use_gelu=(self.moe_runner_config.activation == "gelu"),
1120+ + )
1121+ +
1122+ + # 3. Down GEMM into hidden_states (cuBLAS, non-contiguous input is OK)
1123+ + torch.bmm(
1124+ + gate_up[:, :, : self.intermediate_size_per_partition],
1125+ + self.w2_weight.transpose(1, 2),
1126+ + out=hidden_states[:, :expected_m, :],
1127+ + )
1128+ +
1129+ + return hidden_states
1130+ +
1131+ def forward_npu(
1132+ self,
1133+ dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
8581134diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
8591135index ebdbb42c6..714ffbe0e 100644
8601136--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -972,6 +1248,68 @@ index 00bd68755..5a3ca8a67 100644
9721248 self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)
9731249
9741250 def get_routed_experts(
1251+ diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
1252+ index 8539639d5..b1f614140 100644
1253+ --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
1254+ +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
1255+ @@ -388,6 +388,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
1256+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
1257+ and not get_moe_runner_backend().is_cutlass()
1258+ and not envs.SGLANG_DEEPEP_BF16_DISPATCH.get()
1259+ + and not self.quant_config.get("bf16_weights", False)
1260+ ):
1261+ # TODO hard code 128 block quant,use fp8 communication
1262+ hidden_states = sglang_per_token_group_quant_fp8(
1263+ @@ -466,7 +467,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
1264+ previous_event=previous_event,
1265+ async_finish=self.async_finish,
1266+ allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
1267+ - expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
1268+ + expert_alignment=128
1269+ + if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
1270+ + and not self.quant_config.get("bf16_weights", False)
1271+ + else 1,
1272+ config=DeepEPConfig.get_instance().normal_dispatch_config,
1273+ )
1274+ get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
1275+ @@ -491,7 +495,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
1276+ topk_weights: torch.Tensor,
1277+ ):
1278+
1279+ - if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
1280+ + if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu or self.quant_config.get("bf16_weights", False):
1281+ output = hidden_states
1282+ else:
1283+ raise NotImplementedError() # triton runner was supported but it's temporarily disabled
1284+ @@ -551,10 +555,12 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
1285+ buffer = self._get_buffer()
1286+ topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
1287+ topk_ids = topk_ids.to(torch.int64)
1288+ - expected_m = (
1289+ - hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
1290+ - + self.num_experts
1291+ - ) // self.num_experts
1292+ + # Use a correctness-preserving upper bound for per-expert token count.
1293+ + # In the worst case, every rank routes all local tokens to the same expert.
1294+ + expected_m = min(
1295+ + hidden_states.shape[0] * buffer.group_size,
1296+ + self.num_max_dispatch_tokens_per_rank * buffer.group_size,
1297+ + )
1298+ hidden_states, masked_m, event, hook = self._dispatch_core(
1299+ hidden_states,
1300+ topk_ids,
1301+ @@ -609,7 +615,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
1302+ input_global_scale = self.quant_config.get("input_global_scale", None)
1303+ if input_global_scale is not None:
1304+ use_nvfp4 = True
1305+ - elif not envs.SGLANG_DEEPEP_BF16_DISPATCH.get():
1306+ + elif (
1307+ + not envs.SGLANG_DEEPEP_BF16_DISPATCH.get()
1308+ + and not self.quant_config.get("bf16_weights", False)
1309+ + ):
1310+ use_fp8 = True
1311+
1312+ buffer = self._get_buffer()
9751313diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
9761314index 4cbfed6f9..88b452744 100644
9771315--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
0 commit comments