Skip to content

Commit 98d6c8f

Browse files
committed
Add mixed MoE block quant config and TPU coverage
1 parent 8536859 commit 98d6c8f

6 files changed

Lines changed: 226 additions & 9 deletions

File tree

python/sgl_jax/srt/kernels/quantized_matmul/kernel.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,21 @@ def _get_effective_block_sizes(
219219
return block_size_out, block_size_in
220220

221221

222+
def _should_use_3rd_party_blockwise_kernel(
223+
*,
224+
out_dim: int,
225+
block_size_out: int,
226+
) -> bool:
227+
"""Guard known-bad narrow-N TPU blockwise cases.
228+
229+
When a tensor-parallel column shard collapses to a single output block
230+
(for example local N=128 with block_size_out=128), the third-party TPU
231+
blockwise kernel can produce NaNs on Qwen3-MoE k/v projections. The local
232+
dequantized fallback remains numerically stable for the same inputs.
233+
"""
234+
return out_dim > block_size_out
235+
236+
222237
def _expand_block_scales_to_weight_shape(
223238
w_scale: jax.Array,
224239
out_dim: int,
@@ -307,7 +322,14 @@ def xla_quantized_matmul_local(
307322
# path as fallback for non-TPU / unavailable environments.
308323
out = None
309324
blockwise_3rd_kernel = _get_blockwise_3rd_kernel()
310-
if jax.default_backend() == "tpu" and blockwise_3rd_kernel is not None:
325+
if (
326+
jax.default_backend() == "tpu"
327+
and blockwise_3rd_kernel is not None
328+
and _should_use_3rd_party_blockwise_kernel(
329+
out_dim=int(out_dim),
330+
block_size_out=int(block_size_out),
331+
)
332+
):
311333
try:
312334
w_scale_3rd = _convert_block_scale_to_3rd_layout(
313335
w_scale=w_scale,
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# INT8 dynamic mixed quantization config
2+
# - Linear layers use per-channel quantization.
3+
# - MoE weights use 128x128 block quantization.
4+
5+
quantization:
6+
is_static_checkpoint: false
7+
weight_block_size: [128, 128]
8+
9+
linear:
10+
rules:
11+
- module_path: '.*'
12+
weight_dtype: 'int8'
13+
activation_dtype: null
14+
weight_block_size: null
15+
16+
moe:
17+
weight_dtype: 'int8'
18+
activation_dtype: null

python/sgl_jax/srt/utils/quantization/quantization_utils.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,28 @@ def _get_block_reshape_sharding(
3535
return NamedSharding(input_sharding.mesh, P(*blocked_spec))
3636

3737

38+
def _get_safe_block_quant_input_sharding(
39+
tensor: jax.Array,
40+
quantized_axes: list[int],
41+
) -> NamedSharding | None:
42+
"""Drop sharding on axes that are about to be split into (num_blocks, block)."""
43+
input_sharding = getattr(tensor, "sharding", None)
44+
if not isinstance(input_sharding, NamedSharding):
45+
return None
46+
47+
adjusted_spec = list(input_sharding.spec)
48+
changed = False
49+
for axis_idx in quantized_axes:
50+
if axis_idx < len(adjusted_spec) and adjusted_spec[axis_idx] is not None:
51+
adjusted_spec[axis_idx] = None
52+
changed = True
53+
54+
if not changed:
55+
return None
56+
57+
return NamedSharding(input_sharding.mesh, P(*adjusted_spec))
58+
59+
3860
def apply_linear_quantization(
3961
model_config: ModelConfig, model: nnx.Module, is_static_input: bool = False
4062
) -> nnx.Module:
@@ -76,6 +98,11 @@ def apply_linear_quantization(
7698
# Accept both sglang-jax style and Qwix-style field names.
7799
weight_dtype_str = rule.get("weight_dtype", rule.get("weight_qtype"))
78100
activation_dtype_str = rule.get("activation_dtype", rule.get("act_qtype"))
101+
weight_block_size = (
102+
rule["weight_block_size"]
103+
if "weight_block_size" in rule
104+
else getattr(quant_config, "weight_block_size", None)
105+
)
79106

80107
# Convert string dtypes to jnp dtypes
81108
weight_dtype = DTYPE_MAP.get(weight_dtype_str)
@@ -89,6 +116,7 @@ def apply_linear_quantization(
89116
"pattern": pattern,
90117
"weight_dtype": weight_dtype,
91118
"activation_dtype": activation_dtype,
119+
"weight_block_size": weight_block_size,
92120
}
93121
)
94122

@@ -140,7 +168,7 @@ def _replace_linear_recursive(obj, path: str = "", visited: set | None = None):
140168
weight_dtype=rule["weight_dtype"],
141169
activation_dtype=rule["activation_dtype"],
142170
is_static_input=is_static_input,
143-
weight_block_size=getattr(quant_config, "weight_block_size", None),
171+
weight_block_size=rule["weight_block_size"],
144172
)
145173
# Replace the attribute and free old weights
146174
setattr(obj, attr_name, quantized_linear)
@@ -323,6 +351,7 @@ def quantize_tensor(
323351
axis = [axis]
324352

325353
orig_shape = tensor.shape
354+
original_input_sharding = getattr(tensor, "sharding", None)
326355
mask = None
327356

328357
if block_size is not None:
@@ -356,14 +385,19 @@ def quantize_tensor(
356385

357386
orig_shape = tensor.shape
358387
# Convert all axis into positive values.
359-
axis = sorted([i % tensor.ndim for i in axis])
388+
quantized_axes = sorted([i % tensor.ndim for i in axis])
389+
safe_input_sharding = _get_safe_block_quant_input_sharding(tensor, quantized_axes)
390+
if safe_input_sharding is not None:
391+
tensor = jax.sharding.reshard(tensor, safe_input_sharding)
392+
if mask is not None:
393+
mask = jax.sharding.reshard(mask, safe_input_sharding)
394+
360395
# Shift axis by 1 since its original position is now occupied by
361396
# num_blocks dim. Also, if n axes before an axis was also quantized,
362397
# shift its position by n.
363-
axis = [1 + n + i for n, i in enumerate(axis)]
398+
axis = [1 + n + i for n, i in enumerate(quantized_axes)]
364399

365-
input_sharding = getattr(tensor, "sharding", None)
366-
blocked_out_sharding = _get_block_reshape_sharding(tensor, axis)
400+
blocked_out_sharding = _get_block_reshape_sharding(tensor, quantized_axes)
367401

368402
# Flatten list of lists that contains (num_blocks, block).
369403
blocked_shape = list(itertools.chain(*blocked_shape))
@@ -383,8 +417,8 @@ def quantize_tensor(
383417
# Guard all-zero blocks/tensors: scale==0 would produce 0/0 -> NaN.
384418
scale_safe = scale + (scale == 0).astype(scale.dtype)
385419
tensor_q = jnp.clip(tensor / scale_safe, dtype_min, dtype_max)
386-
if block_size is not None and isinstance(input_sharding, NamedSharding):
387-
tensor_q = jax.lax.reshape(tensor_q, orig_shape, out_sharding=input_sharding)
420+
if block_size is not None and isinstance(original_input_sharding, NamedSharding):
421+
tensor_q = jax.lax.reshape(tensor_q, orig_shape, out_sharding=original_input_sharding)
388422
else:
389423
tensor_q = tensor_q.reshape(orig_shape)
390424
tensor_q = tensor_q.astype(dtype)

python/sgl_jax/test/kernels/quantized_linear_test.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import sgl_jax.srt.kernels.quantized_matmul.kernel as quant_kernel
1111
from sgl_jax.srt.kernels.quantized_matmul.kernel import xla_quantized_matmul_local
12-
from sgl_jax.srt.layers.linear import QuantizedLinear
12+
from sgl_jax.srt.layers.linear import LinearBase, QuantizedLinear
13+
from sgl_jax.srt.utils.quantization.quantization_utils import apply_linear_quantization
1314
from sgl_jax.srt.utils.quantization.quantization_utils import quantize_tensor
1415

1516

@@ -183,8 +184,55 @@ def test_blockwise_tuning_fallback_uses_compatible_seed(monkeypatch):
183184
_assert_blockwise_tuning_fallback_uses_compatible_seed()
184185

185186

187+
def test_linear_rule_weight_block_size_override():
188+
class DummyModel(nnx.Module):
189+
def __init__(self, mesh):
190+
self.proj = LinearBase(
191+
input_size=256,
192+
output_size=512,
193+
use_bias=False,
194+
mesh=mesh,
195+
kernel_axes=(None, None),
196+
params_dtype=jnp.bfloat16,
197+
scope_name="proj",
198+
)
199+
200+
mesh = _create_single_device_mesh()
201+
model = DummyModel(mesh)
202+
203+
class FakeModelConfig:
204+
pass
205+
206+
model_config = FakeModelConfig()
207+
model_config.quantization_config = type(
208+
"FakeQuantConfig",
209+
(),
210+
{
211+
"get_linear_rules": staticmethod(
212+
lambda: [
213+
{
214+
"module_path": ".*",
215+
"weight_dtype": "int8",
216+
"activation_dtype": None,
217+
"weight_block_size": None,
218+
}
219+
]
220+
),
221+
"ignored_layers": None,
222+
"weight_block_size": [128, 128],
223+
},
224+
)()
225+
226+
apply_linear_quantization(model_config, model, is_static_input=False)
227+
228+
assert isinstance(model.proj, QuantizedLinear)
229+
assert model.proj.weight_block_size is None
230+
assert model.proj.weight_scale.value.ndim == 1
231+
232+
186233
if __name__ == "__main__":
187234
for fmt in ("per_channel", "block_channel", "block_quant"):
188235
test_quantized_linear_offline_scale_formats(fmt)
189236
test_xla_quantized_matmul_block_quant_all()
190237
_assert_blockwise_tuning_fallback_uses_compatible_seed()
238+
test_linear_rule_weight_block_size_override()
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import os
2+
import re
3+
import sys
4+
import time
5+
import unittest
6+
7+
import requests
8+
9+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10+
11+
from sgl_jax.srt.utils import kill_process_tree
12+
from sgl_jax.test.test_utils import DEFAULT_URL_FOR_TEST, CustomTestCase, popen_launch_server
13+
14+
15+
class TestW8Int8MoeBlockLinearChannelQuant(CustomTestCase):
16+
model = "Qwen/Qwen3-30B-A3B"
17+
quantization_config_path = "int8_moe_block_128_linear_channel_dynamic.yaml"
18+
other_args = [
19+
"--tp-size=4",
20+
"--ep-size=4",
21+
"--download-dir=/dev/shm",
22+
"--max-running-requests=64",
23+
"--page-size=64",
24+
"--disable-precompile",
25+
]
26+
27+
@classmethod
28+
def setUpClass(cls):
29+
cls.base_url = DEFAULT_URL_FOR_TEST
30+
other_args = [
31+
"--quantization-config-path",
32+
cls.quantization_config_path,
33+
*cls.other_args,
34+
]
35+
cls.process = popen_launch_server(
36+
cls.model,
37+
cls.base_url,
38+
timeout=1800,
39+
other_args=other_args,
40+
check_cache_miss=False,
41+
)
42+
43+
@classmethod
44+
def tearDownClass(cls):
45+
kill_process_tree(cls.process.pid)
46+
try:
47+
cls.process.wait(timeout=30)
48+
except Exception:
49+
pass
50+
time.sleep(5)
51+
52+
def _generate(self, prompt, max_new_tokens=16):
53+
response = requests.post(
54+
self.base_url + "/generate",
55+
json={
56+
"text": prompt,
57+
"sampling_params": {
58+
"temperature": 0,
59+
"max_new_tokens": max_new_tokens,
60+
},
61+
},
62+
timeout=600,
63+
)
64+
response.raise_for_status()
65+
return response.json()
66+
67+
def test_basic_generation(self):
68+
prompts = [
69+
(
70+
"Answer with one word only. What is the capital of France?",
71+
re.compile(r"\bparis\b"),
72+
),
73+
(
74+
"Answer with one number only. What is 12 + 7?",
75+
re.compile(r"\b19\b"),
76+
),
77+
]
78+
79+
for prompt, expected_pattern in prompts:
80+
data = self._generate(prompt, max_new_tokens=8)
81+
text = data.get("text", "")
82+
self.assertTrue(text.strip(), f"Empty generation response: {data}")
83+
self.assertRegex(
84+
text.lower(),
85+
expected_pattern,
86+
msg=f"Unexpected generation text for prompt {prompt!r}: {text!r}",
87+
)
88+
89+
90+
if __name__ == "__main__":
91+
unittest.main()

test/srt/run_suite.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,10 @@ def run_one_file(filename):
486486
TestFile("test/srt/test_model_loader.py", 5),
487487
TestFile("test/srt/quantization/test_w8_quantization.py", 10),
488488
TestFile("test/srt/quantization/test_w8_block_dynamic_quantization.py", 8, runner="pytest"),
489+
TestFile(
490+
"test/srt/quantization/test_w8_moe_block_linear_channel_quantization.py",
491+
15,
492+
),
489493
TestFile("test/srt/test_engine_determine_generation.py", 5),
490494
TestFile("test/srt/test_engine_flush_cache.py", 5),
491495
TestFile("test/srt/test_engine_pause_continue.py", 6),

0 commit comments

Comments
 (0)