Skip to content

Commit 05dea53

Browse files
committed
fix pr error
1 parent 98d6c8f commit 05dea53

5 files changed

Lines changed: 143 additions & 98 deletions

File tree

python/sgl_jax/srt/kernels/fused_moe/v1/kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3086,4 +3086,4 @@ def kernel(
30863086
w1_shared_scale,
30873087
w3_shared_scale,
30883088
w2_shared_scale,
3089-
)
3089+
)

python/sgl_jax/srt/layers/linear.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(
5959
else:
6060
self.bias = None
6161

62-
def __call__(self, x: jax.Array) -> jax.Array | tuple[jax.Array, jax.Array]:
62+
def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array | None]:
6363
"""Forward pass."""
6464
x_2d = x.reshape(-1, x.shape[-1]) if x.ndim > 2 else x
6565

@@ -88,7 +88,7 @@ def _sharded_dot(lhs: jax.Array, rhs: jax.Array) -> jax.Array:
8888
if self.skip_bias_add:
8989
return out, (self.bias.value if self.bias is not None else None)
9090
if self.bias is not None:
91-
return out + self.bias.value
91+
out = out + self.bias.value
9292
return out, None
9393

9494

@@ -176,7 +176,7 @@ def from_linear(
176176
bias = linear.bias.value if linear.bias is not None else None
177177
else:
178178
weight = linear.weight.value
179-
weight_t = weight.T
179+
weight_t = weight.T
180180

181181
if effective_weight_block_size is not None and len(effective_weight_block_size) == 2:
182182
weight_q, weight_scale = quantize_tensor(
@@ -197,7 +197,7 @@ def from_linear(
197197
scope_name=f"quantized_{linear.name}",
198198
)
199199

200-
def __call__(self, x: jax.Array) -> jax.Array | tuple[jax.Array, jax.Array]:
200+
def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array | None]:
201201
"""Forward pass with quantization."""
202202
quantize_activation = self.activation_dtype is not None
203203
x_2d = x.reshape(-1, x.shape[-1]) if x.ndim > 2 else x
@@ -208,7 +208,7 @@ def __call__(self, x: jax.Array) -> jax.Array | tuple[jax.Array, jax.Array]:
208208

209209
input_axis, output_axis = self.kernel_axes[0], self.kernel_axes[1]
210210
w_scale_spec = P(output_axis) if scale_val.ndim == 1 else P(output_axis, input_axis)
211-
211+
212212
in_specs = (P(None, input_axis), P(output_axis, input_axis), w_scale_spec)
213213
out_specs = P(None, output_axis)
214214

@@ -239,5 +239,5 @@ def __call__(self, x: jax.Array) -> jax.Array | tuple[jax.Array, jax.Array]:
239239
if self.skip_bias_add:
240240
return output, (self.bias.value if self.bias is not None else None)
241241
if self.bias is not None:
242-
return output + self.bias.value
243-
return output
242+
output = output + self.bias.value
243+
return output, None

python/sgl_jax/srt/utils/quantization/quantization_utils.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,6 @@ def _replace_linear_recursive(obj, path: str = "", visited: set | None = None):
150150
if any(dot_path.endswith(ignored) or ignored in dot_path for ignored in ignored_layers):
151151
logger.info("Skipping %s - in ignored_layers", child_path)
152152
continue
153-
if "self_attn.o_proj" in dot_path and ignored_layers:
154-
logger.info("Skipping %s - explicit o_proj ignore", child_path)
155-
continue
156153

157154
rule = _find_matching_rule(child_path)
158155
if rule is not None:
@@ -247,63 +244,6 @@ def _quantize_moe_recursive(obj, path: str = "", visited=None):
247244
return model
248245

249246

250-
def adapt_fused_moe_static_block_quant_for_kernel(
    model: nnx.Module,
    *,
    target_subc_quant_wsz: int = 256,
) -> nnx.Module:
    """Prepare static fused-MoE block-quantized layers for the fused kernel.

    Front-end compatibility pass for static checkpoints whose fused MoE
    subchannel block size is smaller than the size the fused kernel supports.
    The module tree under ``model`` is walked and every ``FusedEPMoE`` found is
    asked to adapt itself via ``prepare_static_block_quant_for_fused_kernel``.

    Args:
        model: Root module whose attributes are traversed.
        target_subc_quant_wsz: Value forwarded to each layer's
            ``prepare_static_block_quant_for_fused_kernel`` call.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Imported lazily to avoid a circular import with the MoE layer module.
    from sgl_jax.srt.layers.moe import FusedEPMoE

    num_adapted = 0
    # ids of objects already handled, so shared or cyclic references are
    # visited only once.
    seen_ids = set()

    def _visit(node, location: str = ""):
        nonlocal num_adapted

        node_id = id(node)
        if node_id in seen_ids:
            return
        seen_ids.add(node_id)

        if isinstance(node, FusedEPMoE):
            was_adapted = node.prepare_static_block_quant_for_fused_kernel(
                target_subc_quant_wsz=target_subc_quant_wsz
            )
            if was_adapted:
                num_adapted += 1
                logger.info(
                    "Adapted static fused MoE at %s to subc=%s for fused kernel",
                    location or getattr(node, "name", type(node).__name__),
                    target_subc_quant_wsz,
                )
            # A fused MoE layer is a leaf for this pass; do not descend into it.
            return

        if not hasattr(node, "__dict__"):
            return

        for attr_name, attr_value in node.__dict__.items():
            child_location = f"{location}/{attr_name}" if location else attr_name
            if isinstance(attr_value, nnx.Module):
                _visit(attr_value, child_location)
                continue
            # Only plain lists of modules are searched; other containers are ignored.
            if not isinstance(attr_value, list):
                continue
            for index, element in enumerate(attr_value):
                if isinstance(element, nnx.Module):
                    _visit(element, f"{child_location}[{index}]")

    _visit(model)

    if num_adapted:
        logger.info(
            "Completed static fused MoE block-quant kernel adaptation on %d layer(s)",
            num_adapted,
        )
    return model
305-
306-
307247
def quantize_tensor_simple(
308248
x: jax.Array, dtype: jnp.dtype, dim: int = -1, out_dtype: jnp.dtype = jnp.float32
309249
):

python/sgl_jax/test/kernels/moe_block_quant_test.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,40 +22,39 @@ def test_epmoe_block_quant_logic(weight_block_size, expected_k_blocks_wi, expect
2222
Focuses on weight/scale shapes and placeholder generation.
2323
"""
2424
print("\n>>> Testing EP MoE Block Quantization Logic (CPU) <<<")
25-
25+
2626
# 1. Setup a minimal mesh
2727
devices = jax.devices()
2828
# Use names that are safe for CPU/Standard JAX
2929
mesh = Mesh(np.array(devices[:1]).reshape(1, 1), axis_names=("data", "tensor"))
30-
30+
3131
# 2. Configuration
3232
hidden_size = 512
3333
intermediate_dim = 1024
3434
num_experts = 4
3535
num_experts_per_tok = 1
36-
block_size = 128
37-
36+
3837
# Mock QuantizationConfig
3938
class MockQuantConfig:
40-
def get_moe_weight_dtype(self): return jnp.int8
41-
def get_moe_activation_dtype(self): return None
39+
def get_moe_weight_dtype(self):
40+
return jnp.int8
41+
42+
def get_moe_activation_dtype(self):
43+
return None
44+
4245
@property
43-
def weight_block_size(self): return weight_block_size
44-
46+
def weight_block_size(self):
47+
return weight_block_size
48+
4549
quant_config = MockQuantConfig()
46-
50+
4751
# 3. Initialize EPMoE with Mocked Mesh to bypass sharding checks on CPU
48-
from unittest.mock import MagicMock
49-
mock_mesh = MagicMock(spec=Mesh)
50-
mock_mesh.shape = {"expert": 1, "tensor": 1}
51-
mock_mesh.axis_names = ("expert", "tensor")
52-
mock_mesh.devices = np.array(jax.devices()[:1]).reshape(1, 1)
53-
5452
# We monkeypatch the sharding in EPMoE to be CPU-friendly for this test
5553
original_p = P
5654
import sgl_jax.srt.layers.moe as moe_module
57-
moe_module.P = lambda *args: None # Disable sharding for CPU UT
58-
55+
56+
moe_module.P = lambda *args: None # Disable sharding for CPU UT
57+
5958
try:
6059
moe = EPMoE(
6160
hidden_size=hidden_size,
@@ -67,36 +66,40 @@ def weight_block_size(self): return weight_block_size
6766
quantization_config=quant_config,
6867
)
6968
finally:
70-
moe_module.P = original_p # Restore
69+
moe_module.P = original_p # Restore
7170

72-
7371
# 4. Run Quantization Prep
7472
moe.quantize_weights(is_static=True)
75-
73+
7674
# 5. Assert Scale Shapes
7775
# EPMoE logic: k_blocks = hidden_size // block_size
7876
k_blocks_wi = expected_k_blocks_wi
7977
k_blocks_wo = expected_k_blocks_wo
80-
78+
8179
print(f" Expert Count: {num_experts}")
8280
print(f" K Blocks (WI): {k_blocks_wi}")
8381
print(f" K Blocks (WO): {k_blocks_wo}")
84-
82+
8583
expected_wi_shape = (num_experts, k_blocks_wi, 1, intermediate_dim)
8684
expected_wo_shape = (num_experts, k_blocks_wo, 1, hidden_size)
87-
85+
8886
print(f" WI_0 Scale Shape: {moe.wi_0_scale.value.shape}")
8987
print(f" WO Scale Shape: {moe.wo_scale.value.shape}")
90-
91-
assert moe.wi_0_scale.value.shape == expected_wi_shape, f"WI shape mismatch: {moe.wi_0_scale.value.shape} vs {expected_wi_shape}"
92-
assert moe.wo_scale.value.shape == expected_wo_shape, f"WO shape mismatch: {moe.wo_scale.value.shape} vs {expected_wo_shape}"
93-
88+
89+
assert (
90+
moe.wi_0_scale.value.shape == expected_wi_shape
91+
), f"WI shape mismatch: {moe.wi_0_scale.value.shape} vs {expected_wi_shape}"
92+
assert (
93+
moe.wo_scale.value.shape == expected_wo_shape
94+
), f"WO shape mismatch: {moe.wo_scale.value.shape} vs {expected_wo_shape}"
95+
9496
print(" Shape Verification: PASSED")
95-
97+
9698
# 6. Verify Content (Should be zeros as initialized in is_static=True)
9799
assert jnp.all(moe.wi_0_scale.value == 0)
98100
print(" Content Verification: PASSED")
99101

102+
100103
if __name__ == "__main__":
101104
try:
102105
test_epmoe_block_quant_logic(None, 1, 1)

python/sgl_jax/test/kernels/quantized_linear_test.py

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import jax.numpy as jnp
66
import numpy as np
77
import pytest
8+
from flax import nnx
89
from jax.sharding import Mesh
910

1011
import sgl_jax.srt.kernels.quantized_matmul.kernel as quant_kernel
@@ -101,7 +102,8 @@ def test_quantized_linear_offline_scale_formats(scale_format):
101102
)
102103

103104
ref_out = jnp.dot(x, w_fp.T)
104-
out = quant_linear(x)
105+
out, bias = quant_linear(x)
106+
assert bias is None
105107

106108
_assert_close(f"Offline QuantizedLinear ({scale_format})", out, ref_out)
107109

@@ -198,7 +200,8 @@ def __init__(self, mesh):
198200
)
199201

200202
mesh = _create_single_device_mesh()
201-
model = DummyModel(mesh)
203+
with jax.set_mesh(mesh):
204+
model = DummyModel(mesh)
202205

203206
class FakeModelConfig:
204207
pass
@@ -230,9 +233,108 @@ class FakeModelConfig:
230233
assert model.proj.weight_scale.value.ndim == 1
231234

232235

236+
def test_linear_return_contract_with_bias():
    """Check the ``(output, bias)`` return pair of biased linear layers.

    A ``LinearBase`` built with ``use_bias=True`` and the ``QuantizedLinear``
    derived from it must both return a 2-tuple whose first element has shape
    ``(2, 32)`` and whose second element is ``None``.
    """
    single_mesh = _create_single_device_mesh()
    inputs = jnp.ones((2, 64), dtype=jnp.bfloat16)
    expected_shape = (2, 32)

    # Unquantized layer.
    with jax.set_mesh(single_mesh):
        dense = LinearBase(
            input_size=64,
            output_size=32,
            use_bias=True,
            mesh=single_mesh,
            kernel_axes=(None, None),
            params_dtype=jnp.bfloat16,
            scope_name="biased_proj",
        )
        dense_out, dense_bias = dense(inputs)

    assert dense_out.shape == expected_shape
    assert dense_bias is None

    # Quantized layer converted from the one above.
    with jax.set_mesh(single_mesh):
        quantized = QuantizedLinear.from_linear(
            dense,
            weight_dtype=jnp.int8,
            activation_dtype=None,
            is_static_input=False,
        )
        quant_out, quant_bias = quantized(inputs)

    assert quant_out.shape == expected_shape
    assert quant_bias is None
266+
267+
268+
def test_ignored_layers_only_skips_requested_paths():
    """Check that ``ignored_layers`` skips only the layers it names.

    With an unrelated entry in ``ignored_layers`` both ``q_proj`` and
    ``o_proj`` must be converted to ``QuantizedLinear``. With
    ``"self_attn.o_proj"`` listed, only ``q_proj`` is converted and ``o_proj``
    stays a ``LinearBase``.
    """

    def _build_projection(mesh, scope_name):
        # Both projections share every setting except their scope name.
        return LinearBase(
            input_size=64,
            output_size=32,
            use_bias=False,
            mesh=mesh,
            kernel_axes=(None, None),
            params_dtype=jnp.bfloat16,
            scope_name=scope_name,
        )

    class SelfAttn(nnx.Module):
        def __init__(self, mesh):
            self.q_proj = _build_projection(mesh, "q_proj")
            self.o_proj = _build_projection(mesh, "o_proj")

    class DummyBlock(nnx.Module):
        def __init__(self, mesh):
            self.self_attn = SelfAttn(mesh)

    class FakeModelConfig:
        pass

    def _linear_rules():
        # A fresh catch-all rule list on every call.
        return [
            {
                "module_path": ".*",
                "weight_dtype": "int8",
                "activation_dtype": None,
                "weight_block_size": None,
            }
        ]

    def _make_config(ignored_layers):
        quant_config_cls = type(
            "FakeQuantConfig",
            (),
            {
                "get_linear_rules": staticmethod(_linear_rules),
                "ignored_layers": ignored_layers,
                "weight_block_size": [128, 128],
            },
        )
        model_config = FakeModelConfig()
        model_config.quantization_config = quant_config_cls()
        return model_config

    mesh = _create_single_device_mesh()

    # (ignored_layers, type o_proj is expected to have after quantization)
    scenarios = (
        (["some_other_layer"], QuantizedLinear),
        (["self_attn.o_proj"], LinearBase),
    )
    for ignored_layers, expected_o_proj_type in scenarios:
        with jax.set_mesh(mesh):
            model = DummyBlock(mesh)
            apply_linear_quantization(_make_config(ignored_layers), model, is_static_input=False)
        assert isinstance(model.self_attn.q_proj, QuantizedLinear)
        assert isinstance(model.self_attn.o_proj, expected_o_proj_type)
331+
332+
233333
if __name__ == "__main__":
234334
for fmt in ("per_channel", "block_channel", "block_quant"):
235335
test_quantized_linear_offline_scale_formats(fmt)
236336
test_xla_quantized_matmul_block_quant_all()
237337
_assert_blockwise_tuning_fallback_uses_compatible_seed()
238338
test_linear_rule_weight_block_size_override()
339+
test_linear_return_contract_with_bias()
340+
test_ignored_layers_only_skips_requested_paths()

0 commit comments

Comments
 (0)