Skip to content

Commit 0dfd54d

Browse files
kkHuang-amd, wunhuang, and wghuang
authored
Optimized deepseek-v3/r1 model performance on mxfp4 run (sgl-project#9671)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: wghuang <wghuang@amd.com>
1 parent bcbeed7 commit 0dfd54d

File tree

7 files changed

+458
-62
lines changed

7 files changed

+458
-62
lines changed

python/sglang/srt/layers/communicator.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,22 @@
4242
)
4343
from sglang.srt.managers.schedule_batch import global_server_args_dict
4444
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
45-
from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
45+
from sglang.srt.utils import (
46+
get_bool_env_var,
47+
is_cuda,
48+
is_flashinfer_available,
49+
is_gfx95_supported,
50+
is_hip,
51+
is_sm100_supported,
52+
)
4653

4754
_is_flashinfer_available = is_flashinfer_available()
4855
_is_sm100_supported = is_cuda() and is_sm100_supported()
56+
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
57+
_is_gfx95_supported = is_gfx95_supported()
58+
59+
if _use_aiter and _is_gfx95_supported:
60+
from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
4961

5062
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
5163

@@ -201,6 +213,7 @@ def prepare_attn(
201213
hidden_states: torch.Tensor,
202214
residual: torch.Tensor,
203215
forward_batch: ForwardBatch,
216+
qaunt_format: str = "",
204217
):
205218
if hidden_states.shape[0] == 0:
206219
residual = hidden_states
@@ -218,11 +231,34 @@ def prepare_attn(
218231
else:
219232
if residual is None:
220233
residual = hidden_states
221-
hidden_states = self.input_layernorm(hidden_states)
234+
235+
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
236+
hidden_states = fused_rms_mxfp4_quant(
237+
hidden_states,
238+
self.input_layernorm.weight,
239+
self.input_layernorm.variance_epsilon,
240+
None,
241+
None,
242+
None,
243+
None,
244+
)
245+
else:
246+
hidden_states = self.input_layernorm(hidden_states)
222247
else:
223-
hidden_states, residual = self.input_layernorm(
224-
hidden_states, residual
225-
)
248+
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
249+
hidden_states, residual = fused_rms_mxfp4_quant(
250+
hidden_states,
251+
self.input_layernorm.weight,
252+
self.input_layernorm.variance_epsilon,
253+
None,
254+
None,
255+
None,
256+
residual,
257+
)
258+
else:
259+
hidden_states, residual = self.input_layernorm(
260+
hidden_states, residual
261+
)
226262

227263
hidden_states = self._communicate_simple_fn(
228264
hidden_states=hidden_states,

python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
99
from aiter.ops.shuffle import shuffle_weight
1010
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
11+
from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
1112
from aiter.ops.triton.quant import dynamic_mxfp4_quant
1213
from aiter.utility import dtypes
1314
from aiter.utility.fp4_utils import e8m0_shuffle
@@ -38,15 +39,6 @@ def get_min_capability(cls) -> int:
3839
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
3940
return
4041

41-
# for aiter implement
42-
# wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
43-
# w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
44-
45-
# layer.weight = torch.nn.Parameter(wshuffle,
46-
# requires_grad=False)
47-
# layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
48-
# requires_grad=False)
49-
5042
def create_weights(
5143
self,
5244
layer: torch.nn.Module,
@@ -93,26 +85,53 @@ def apply_weights(
9385
x: torch.Tensor,
9486
bias: Optional[torch.Tensor] = None,
9587
) -> torch.Tensor:
96-
97-
out_dtype = x.dtype
98-
# M = x.shape[0]
99-
# N = layer.weight.shape[0]
100-
101-
# quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
102-
# x, x_scales_shuffle = quant_func(x, shuffle=True)
103-
104-
# y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
105-
106-
# out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
107-
108-
# return out[:M]
109-
110-
# triton implement
111-
x_q, x_s = dynamic_mxfp4_quant(x)
112-
y = torch.empty(
113-
x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
88+
# This path does not have support for bias currently
89+
assert bias is None, "bias is not supported"
90+
91+
three_d = False
92+
x_s = None
93+
y = None
94+
if isinstance(x, tuple):
95+
assert len(x) in [
96+
2,
97+
3,
98+
], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
99+
if len(x) == 2:
100+
x, x_s = x
101+
elif len(x) == 3:
102+
x, x_s, y = x
103+
104+
use_fused_quant_gemm = (
105+
x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
114106
)
115107

116-
out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
117-
118-
return out
108+
if x.dim() == 3:
109+
three_d = True
110+
x = x.view(-1, x.shape[-1])
111+
output_shape = [*x.shape[:-1], layer.weight.shape[0]]
112+
113+
# use_fused_quant_gemm = true, x_q is a bf16/fp16 num
114+
# x_s is not None = true, x_q is uint8 num
115+
if use_fused_quant_gemm or x_s is not None:
116+
x_q = x
117+
else:
118+
x_q, x_s = dynamic_mxfp4_quant(x)
119+
120+
if y is None:
121+
y = torch.empty(
122+
x_q.shape[0],
123+
layer.weight.shape[0],
124+
device=x_q.device,
125+
dtype=self.out_dtype,
126+
)
127+
128+
if use_fused_quant_gemm:
129+
gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
130+
y = y.to(x.dtype)
131+
else:
132+
gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
133+
134+
if three_d:
135+
return y.view(*output_shape)
136+
137+
return y

python/sglang/srt/layers/quantization/quark/utils.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from types import MappingProxyType
66
from typing import Any, Optional
77

8+
import torch
9+
from aiter.ops.triton.quant import dynamic_mxfp4_quant
10+
from torch import nn
11+
812

913
def deep_compare(dict1: Any, dict2: Any) -> bool:
1014
if type(dict1) is not type(dict2):
@@ -105,3 +109,96 @@ def _is_equal_or_regex_match(
105109
elif target == value:
106110
return True
107111
return False
112+
113+
114+
# utility for tensor dims > 2 cases
115+
def b_dynamic_mxfp4_quant(x):
    """Batched mxfp4 quantization for a 3-D tensor.

    ``dynamic_mxfp4_quant`` only handles 2-D input, so collapse the two
    leading dims, quantize, and restore the batch layout afterwards.
    """
    num_heads, batch, dim = x.shape
    flat_q, flat_scales = dynamic_mxfp4_quant(x.reshape(-1, dim))
    # fp4 packs two values per byte (dim // 2); one e8m0 scale per 32 values.
    return (
        flat_q.view(num_heads, batch, dim // 2),
        flat_scales.view(num_heads, batch, dim // 32),
    )
119+
120+
121+
def mxfp4_to_f32(x, is_threed):
    """Dequantize packed mxfp4 values (two FP4 codes per uint8) to float32.

    Args:
        x: uint8 tensor whose last dimension packs two FP4 nibbles per byte
           (low nibble first).
        is_threed: True when ``x`` is 3-D (uses ``...`` indexing), False for 2-D.

    Returns:
        float32 tensor with the last dimension doubled — one decoded value
        per nibble.
    """
    # 2 because we pack fp4 in uint8: duplicate each byte so every output
    # slot has a copy to decode, then split it into low/high nibbles.
    x = x.repeat_interleave(2, dim=-1)
    if is_threed:
        x[..., ::2] = x[..., ::2] & 0xF
        x[..., 1::2] = x[..., 1::2] >> 4
    else:
        x[:, ::2] = x[:, ::2] & 0xF
        x[:, 1::2] = x[:, 1::2] >> 4

    # The 16 representable FP4 (E2M1) values, indexed by the 4-bit code.
    mxfp4_list = [
        0.0,
        0.5,
        1.0,
        1.5,
        2.0,
        3.0,
        4.0,
        6.0,
        -0.0,
        -0.5,
        -1.0,
        -1.5,
        -2.0,
        -3.0,
        -4.0,
        -6.0,
    ]
    # Build the lookup table on the input's device instead of hard-coding
    # "cuda", so the function also works for CPU tensors (tests, offline
    # conversion) and never mixes devices during the gather below.
    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device=x.device)
    return mxfp4_in_f32[x.long()]
151+
152+
153+
def e8m0_to_f32(x):
    """Convert e8m0 codes (8-bit exponent-only floats) to float32.

    A code ``e`` represents ``2**(e - 127)`` (IEEE-754 bias without a
    mantissa); the all-ones code 255 is reserved and decodes to NaN.
    """
    # Decode the biased exponent. For e == 255 this computes 2**128, which
    # overflows float32 to inf before we patch it below.
    x_f32 = 2 ** ((x.to(torch.float32)) - 127)

    # NaN sentinel: test the *code* (255), not the decoded value. The old
    # check ``x_f32 == 128`` matched code 134 (2**7 == 128.0) and could
    # never match code 255, whose decoded value is inf, not 128.
    x_f32[x == 255] = float("nan")
    return x_f32
165+
166+
167+
def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
    """Split the fused kv_b_proj weight into K- and V-projection halves and
    mxfp4-quantize both, returning the packed weights plus their scales.

    Args:
        self_attn: attention module supplying ``qk_nope_head_dim`` /
            ``v_head_dim`` and, for the uint8 path, ``kv_b_proj.weight_scale``.
        w: fused kv_b_proj weight — bf16 (unquantized) or uint8 (already
            packed mxfp4).
        quant_format: quantization format string; only "mxfp4" is handled.

    Returns:
        (w_kc, w_s_kc, w_vc, w_s_vc): packed mxfp4 weights and e8m0 scales.

    NOTE(review): if ``quant_format`` lacks "mxfp4" (or ``w`` has another
    dtype), the return line references unbound locals and raises
    UnboundLocalError — callers must only invoke this for mxfp4 checkpoints.
    """
    if "mxfp4" in quant_format:
        # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
        # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
        # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
        if w.dtype == torch.bfloat16:
            # Rows interleave per-head (qk_nope | v) chunks; unflatten then
            # split recovers the two projections.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            # Quantize w_kc along its reduction dim (hence the transpose),
            # then transpose the packed result and its scales back.
            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
            w_kc = w_kc.transpose(-2, -1)
            w_s_kc = w_s_kc.transpose(-2, -1)
            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
            # Force the scale tensors into the memory layout the downstream
            # gemm expects — presumably stride-sensitive; TODO confirm.
            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
        elif w.dtype == torch.uint8:  # static quant for mxfp4
            # when dtype is uint8, it means the w has been quantized to mxfp4 format
            # but we must separate it to w_kc and w_vc.
            # The quantized tensor size is only half of original tensor size
            # and the scaling factor is 1/32, the transpose behavior will be not correct
            # need to upcast it to fp32 to separate w to w_kc and w_vc
            # to ensure the following transpose behavior is correct
            # and then do mxfp4 quant again
            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
            # Expand per-32-group e8m0 scales to per-element and apply them.
            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
            w = w * w_scales
            # From here the flow mirrors the bf16 branch above.
            w_kc, w_vc = w.unflatten(
                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
            w_kc = w_kc.transpose(-2, -1)
            w_s_kc = w_s_kc.transpose(-2, -1)
            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
            w_s_vc = w_s_vc.contiguous().transpose(1, 2)

    return w_kc, w_s_kc, w_vc, w_s_vc
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
2+
batched_gemm_afp4wfp4_pre_quant,
3+
)
4+
from aiter.ops.triton.fused_mxfp4_quant import (
5+
fused_flatten_mxfp4_quant,
6+
fused_rms_mxfp4_quant,
7+
)
8+
9+
__all__ = [
10+
"fused_rms_mxfp4_quant",
11+
"fused_flatten_mxfp4_quant",
12+
"batched_gemm_afp4wfp4_pre_quant",
13+
]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
3+
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
4+
from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
5+
6+
from sglang.srt.utils import BumpAllocator
7+
8+
__all__ = ["fused_qk_rope_cat"]
9+
10+
11+
def aiter_dsv3_router_gemm(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    gemm_output_zero_allocator: BumpAllocator = None,
):
    """Router GEMM for DeepSeek-V3 on ROCm/aiter.

    For small batches (M <= 256) the atomic kernel accumulates into a
    zero-initialized fp32 buffer; larger batches use the plain GEMM.

    Args:
        hidden_states: (M, K) activations.
        weight: (N, K) router weight.
        gemm_output_zero_allocator: optional bump allocator handing out
            pre-zeroed fp32 buffers; falls back to ``torch.zeros`` when None.

    Returns:
        (M, N) router logits cast back to ``hidden_states.dtype``.
    """
    M = hidden_states.shape[0]
    N = weight.shape[0]
    y = None

    if M <= 256:
        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
        # for now it is also coupled with zero allocator.
        if gemm_output_zero_allocator is not None:
            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
        else:
            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)

    if y is not None:
        # Atomic kernel requires the pre-zeroed fp32 accumulator `y`.
        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
    else:
        logits = gemm_a16w16(hidden_states, weight)

    return logits
34+
35+
36+
def get_dsv3_gemm_output_zero_allocator_size(
    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
):
    """Size (in elements) of the pre-zeroed router-GEMM output pool.

    Only the standard DeepSeek-V3 configuration (7168-dim embeddings with
    256 routed experts) uses the zero allocator; anything else gets 0.
    """
    is_dsv3_shape = embedding_dim == 7168 and n_routed_experts == 256
    if not is_dsv3_shape:
        return 0
    # 256 rows per layer, each holding allocate_size + n_routed_experts slots.
    return num_moe_layers * 256 * (allocate_size + n_routed_experts)

0 commit comments

Comments
 (0)