Skip to content

Commit 7f076dc

Browse files
committed
fix according to review comments
1 parent 05dea53 commit 7f076dc

11 files changed

Lines changed: 348 additions & 72 deletions

File tree

python/sgl_jax/srt/configs/quantization_config.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import os
99
from dataclasses import dataclass
10+
from numbers import Integral
1011

1112
import jax.numpy as jnp
1213
import yaml
@@ -55,6 +56,32 @@ def _resolve_config_path(config_path: str) -> str:
5556
)
5657

5758

59+
def _normalize_weight_block_size(
60+
weight_block_size: list[int] | tuple[int, int] | None,
61+
) -> tuple[int, int] | None:
62+
if weight_block_size is None:
63+
return None
64+
if not isinstance(weight_block_size, (list, tuple)) or len(weight_block_size) != 2:
65+
raise ValueError(
66+
"quantization.weight_block_size must be a 2-element list/tuple "
67+
f"[block_n, block_k], got {weight_block_size!r}"
68+
)
69+
block_n, block_k = weight_block_size
70+
if not isinstance(block_n, Integral) or not isinstance(block_k, Integral):
71+
raise ValueError(
72+
"quantization.weight_block_size values must be integers, "
73+
f"got {weight_block_size!r}"
74+
)
75+
block_n = int(block_n)
76+
block_k = int(block_k)
77+
if block_n <= 0 or block_k <= 0:
78+
raise ValueError(
79+
"quantization.weight_block_size values must be > 0, "
80+
f"got {weight_block_size!r}"
81+
)
82+
return (block_n, block_k)
83+
84+
5885
@dataclass
5986
class QuantizationConfig:
6087
"""Quantization configuration with explicit settings (no fallbacks).
@@ -65,15 +92,15 @@ class QuantizationConfig:
6592
moe_activation_dtype: Dtype for MoE activation quantization (None = no quantization)
6693
is_static_checkpoint: Whether the checkpoint is static (true for checkpoints quantized offline, false for on-the-fly quantization)
6794
ignored_layers: Optional list of layer name patterns to exclude from quantization
68-
weight_block_size: Optional block sizes for static checkpoints (e.g., [128, 128])
95+
weight_block_size: Optional block sizes for block quantization (e.g., [128, 128])
6996
"""
7097

7198
linear_rules: list[dict] | None = None
7299
moe_weight_dtype: jnp.dtype | None = None
73100
moe_activation_dtype: jnp.dtype | None = None
74101
is_static_checkpoint: bool = False
75102
ignored_layers: list[str] | None = None
76-
weight_block_size: list[int] | None = None
103+
weight_block_size: tuple[int, int] | None = None
77104

78105
@classmethod
79106
def from_yaml(cls, yaml_path: str) -> "QuantizationConfig":
@@ -127,7 +154,7 @@ def from_yaml(cls, yaml_path: str) -> "QuantizationConfig":
127154
moe_weight_dtype = _str_to_dtype(moe_section.get("weight_dtype"))
128155
moe_activation_dtype = _str_to_dtype(moe_section.get("activation_dtype"))
129156
is_static_checkpoint = quant.get("is_static_checkpoint", False)
130-
weight_block_size = quant.get("weight_block_size")
157+
weight_block_size = _normalize_weight_block_size(quant.get("weight_block_size"))
131158

132159
return cls(
133160
linear_rules=linear_rules,

python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/blockwise_kernel.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99

1010
from . import util
1111
from .tuned_block_sizes import (
12-
TunedValue, get_device_vmem_limit, get_tuned_block_sizes)
13-
from .util import (get_kernel_name,
14-
next_multiple,
15-
unfold_args)
12+
TunedValue,
13+
get_device_vmem_limit,
14+
get_tuned_block_sizes,
15+
)
16+
from .util import get_kernel_name, next_multiple, unfold_args
1617

1718
quantize_tensor = util.quantize_tensor
1819
MXU_SIZE = 256
@@ -215,7 +216,7 @@ def accum(is_first_step, is_last_step):
215216
out_specs=pl.BlockSpec((batch_block_size, out_block_size),
216217
lambda b, o, i: (b, o)),
217218
scratch_shapes=[
218-
pltpu.VMEM((batch_block_size, out_block_size), jnp.bfloat16)
219+
pltpu.VMEM((batch_block_size, out_block_size), acc_dtype)
219220
],
220221
grid=(n_batch, n_out, n_in),
221222
),

python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/util.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55

66
import jax
77
import jax.numpy as jnp
8-
from jax._src import dtypes
98

109
from .tuned_block_sizes import TunedValue
1110

1211

12+
def _dtype_bits(dtype: jnp.dtype) -> int:
13+
return jnp.dtype(dtype).itemsize * 8
14+
15+
1316
def unfold_args(
1417
conditions: tuple[jax.Array | bool, ...],
1518
fn_conditions: tuple[bool, ...],
@@ -191,7 +194,8 @@ def quantize_array(
191194

192195
# TODO(kyuyeunk): Investigate performance gain from non xlu transpose.
193196
scale = jnp.transpose(x_abs_max / dtype_max)
194-
scale_inv = jnp.nan_to_num(1 / scale, dtype_max)
197+
scale = jnp.where(scale == 0, 1.0, scale)
198+
scale_inv = jnp.nan_to_num(1 / scale, nan=dtype_max, posinf=dtype_max, neginf=-dtype_max)
195199
return (x * scale_inv).astype(quant_dtype), scale.astype(jnp.float32)
196200

197201

@@ -215,13 +219,11 @@ def get_vmem_limit(
215219
"""Calculate VMEM limit for the kernel."""
216220

217221
# Calculate in/out VMEM size.
218-
x_size = (batch_block_size * in_block_size * dtypes.itemsize_bits(x_dtype))
219-
x_abs_max_size = batch_block_size * dtypes.itemsize_bits(scale_dtype)
220-
w_q_size = (out_block_size * in_block_size *
221-
dtypes.itemsize_bits(w_q_dtype))
222-
w_scale_size = out_block_size * dtypes.itemsize_bits(scale_dtype)
223-
out_size = (batch_block_size * out_block_size *
224-
dtypes.itemsize_bits(out_dtype))
222+
x_size = batch_block_size * in_block_size * _dtype_bits(x_dtype)
223+
x_abs_max_size = batch_block_size * _dtype_bits(scale_dtype)
224+
w_q_size = out_block_size * in_block_size * _dtype_bits(w_q_dtype)
225+
w_scale_size = out_block_size * _dtype_bits(scale_dtype)
226+
out_size = batch_block_size * out_block_size * _dtype_bits(out_dtype)
225227

226228
vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
227229
vmem_in_out *= 2 # Account for compute and vreg spills.
@@ -235,11 +237,9 @@ def get_vmem_limit(
235237
vmem_in_out += out_size if (n_batch > 1 or n_out > 1) else 0
236238

237239
# Calculate scratch VMEM size.
238-
acc_size = (batch_block_size * out_block_size *
239-
dtypes.itemsize_bits(acc_dtype))
240-
x_q_size = (batch_block_size * in_block_size *
241-
dtypes.itemsize_bits(x_q_dtype))
242-
x_scale_size = batch_block_size * dtypes.itemsize_bits(scale_dtype)
240+
acc_size = batch_block_size * out_block_size * _dtype_bits(acc_dtype)
241+
x_q_size = batch_block_size * in_block_size * _dtype_bits(x_q_dtype)
242+
x_scale_size = batch_block_size * _dtype_bits(scale_dtype)
243243

244244
vmem_scratch = acc_size if save_acc else 0
245245
vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
@@ -277,10 +277,14 @@ def validate_inputs(
277277
# Verify input shapes.
278278
if x.shape[1] != w_q.shape[1]:
279279
raise ValueError(f'{x.shape[1]=} must be equal to {w_q.shape[1]=}')
280-
if w_q.shape[0] != w_scale.shape[1] and (w_scale.ndim == 3 and w_q.shape[0]
281-
!= w_scale.shape[2]):
282-
raise ValueError(
283-
f"{w_q.shape[0]=} must be equal to {w_scale.shape[1]=}")
280+
if w_scale.ndim == 2:
281+
if w_q.shape[0] != w_scale.shape[1]:
282+
raise ValueError(f"{w_q.shape[0]=} must be equal to {w_scale.shape[1]=}")
283+
elif w_scale.ndim == 3:
284+
if w_q.shape[0] != w_scale.shape[2]:
285+
raise ValueError(f"{w_q.shape[0]=} must be equal to {w_scale.shape[2]=}")
286+
else:
287+
raise ValueError(f"Unsupported {w_scale.ndim=} for quantized weight scale.")
284288
if x_abs_max is not None and x_abs_max.shape != (1, x.shape[0]):
285289
raise ValueError(
286290
f"{x_abs_max.shape=} must be equal to (1, {x.shape[0]=})")
@@ -317,5 +321,5 @@ def quantize_block(data, axis, target_dtype):
317321
if jnp.issubdtype(target_dtype, jnp.floating):
318322
data_q = (data / scale).clip(dtype_min, dtype_max).astype(target_dtype)
319323
else:
320-
data_q = jnp.round(data / scale).astype(target_dtype)
324+
data_q = jnp.clip(jnp.round(data / scale), dtype_min, dtype_max).astype(target_dtype)
321325
return data_q, scale

python/sgl_jax/srt/kernels/quantized_matmul/kernel.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
# SPDX-License-Identifier: Apache-2.0
22
"""Quantized matmul kernel."""
33

4+
import functools
45
import importlib
6+
import logging
57
import math
8+
import re
69

710
import jax
811
import jax.numpy as jnp
912
from jax import lax
1013

1114
from sgl_jax.srt.utils.quantization.quantization_utils import quantize_tensor_simple
1215

16+
logger = logging.getLogger(__name__)
17+
1318

1419
_BLOCKWISE_3RD_KERNEL = None
1520
_TRIED_LOADING_BLOCKWISE_3RD_KERNEL = False
@@ -32,6 +37,7 @@ def _get_blockwise_3rd_kernel():
3237
module = importlib.import_module(f"{package}.3rd_quantized_matmul")
3338
_BLOCKWISE_3RD_KERNEL = getattr(module, "quantized_matmul", None)
3439
except Exception:
40+
logger.debug("Failed to import third-party blockwise quantized matmul kernel.", exc_info=True)
3541
_BLOCKWISE_3RD_KERNEL = None
3642
return _BLOCKWISE_3RD_KERNEL
3743

@@ -58,6 +64,7 @@ def _get_blockwise_3rd_tuning_api():
5864
_BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES = getattr(module, "get_tuned_block_sizes", None)
5965
_BLOCKWISE_3RD_TUNED_BLOCK_SIZES = getattr(module, "TUNED_BLOCK_SIZES", None)
6066
except Exception:
67+
logger.debug("Failed to import third-party blockwise tuning metadata.", exc_info=True)
6168
_BLOCKWISE_3RD_TUNED_VALUE_CLS = None
6269
_BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES = None
6370
_BLOCKWISE_3RD_TUNED_BLOCK_SIZES = None
@@ -102,13 +109,26 @@ def _candidate(units_value: int) -> int:
102109
return min(candidates, key=lambda value: (abs(value - x), -value))
103110

104111

112+
@functools.lru_cache(maxsize=1)
113+
def _get_current_tpu_version() -> int:
114+
try:
115+
kind = jax.devices()[0].device_kind
116+
except Exception:
117+
return -1
118+
match = re.match(r"^TPU[^\d]*(\d+)", kind)
119+
if match is None:
120+
return -1
121+
return int(match.group(1))
122+
123+
105124
def _iter_blockwise_tuned_candidates(
106125
tuned_block_sizes: dict | None,
107126
n_batch: int,
108127
n_out: int,
109128
n_in: int,
110129
x_q_dtype: jnp.dtype,
111130
w_q_dtype: jnp.dtype,
131+
tpu_version: int,
112132
):
113133
if not tuned_block_sizes:
114134
return []
@@ -121,6 +141,8 @@ def _iter_blockwise_tuned_candidates(
121141

122142
candidates = []
123143
for key, value in tuned_block_sizes.items():
144+
if getattr(key, "tpu_version", tpu_version) != tpu_version:
145+
continue
124146
if key.w_q_dtype != w_q_dtype_name:
125147
continue
126148
if key.x_q_dtype not in compatible_x_dtype_names:
@@ -162,6 +184,7 @@ def _get_safe_blockwise_tuned_value(
162184
n_in=n_in,
163185
x_q_dtype=x_q_dtype,
164186
w_q_dtype=w_q_dtype,
187+
tpu_version=_get_current_tpu_version(),
165188
)
166189
if compatible_candidates:
167190
tuned = compatible_candidates[0]
@@ -175,6 +198,7 @@ def _get_safe_blockwise_tuned_value(
175198
w_q_dtype=jnp.dtype(w_q_dtype).name,
176199
)
177200
except Exception:
201+
logger.debug("Failed to query tuned block sizes from third-party kernel.", exc_info=True)
178202
tuned = None
179203
if tuned is None:
180204
tuned = tuned_value_cls(128, 128, 128, 1)
@@ -356,6 +380,7 @@ def xla_quantized_matmul_local(
356380
tuned_value=tuned_value,
357381
)
358382
except Exception:
383+
logger.debug("Falling back from third-party blockwise kernel to local dequant path.", exc_info=True)
359384
out = None
360385

361386
if out is None:

python/sgl_jax/srt/layers/linear.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
"""Linear layers."""
33

4-
import math
54
from collections.abc import Sequence
65
from functools import partial
76

@@ -165,7 +164,13 @@ def from_linear(
165164
)
166165
bias = linear.bias.value if linear.bias is not None else None
167166
else:
168-
weight_q = weight.T.astype(weight_dtype)
167+
if weight.dtype != weight_dtype:
168+
raise ValueError(
169+
"QuantizedLinear.from_linear(..., is_static_input=True) requires "
170+
"pre-quantized concrete weights or abstract shapes. "
171+
f"Got weight.dtype={weight.dtype}, expected {weight_dtype}."
172+
)
173+
weight_q = weight.T
169174
if effective_weight_block_size is not None and len(effective_weight_block_size) == 2:
170175
block_n, block_k = int(effective_weight_block_size[0]), int(effective_weight_block_size[1])
171176
out_blocks = (weight_q.shape[0] + block_n - 1) // block_n
@@ -192,7 +197,7 @@ def from_linear(
192197
weight_q=weight_q, weight_scale=weight_scale, bias=bias,
193198
activation_dtype=activation_dtype, mesh=linear.mesh,
194199
kernel_axes=linear.kernel_axes,
195-
skip_bias_add=linear.skip_bias_add or linear.bias is None,
200+
skip_bias_add=linear.skip_bias_add,
196201
params_dtype=linear.params_dtype, weight_block_size=effective_weight_block_size,
197202
scope_name=f"quantized_{linear.name}",
198203
)
@@ -212,22 +217,13 @@ def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array | None]:
212217
in_specs = (P(None, input_axis), P(output_axis, input_axis), w_scale_spec)
213218
out_specs = P(None, output_axis)
214219

215-
# Handle block size inference
216-
effective_weight_block_size = self.weight_block_size
217-
if scale_val.ndim == 2 and self.weight_block_size is not None:
218-
global_out_size, global_in_size = self.weight_q.value.shape
219-
inferred_bs_out = math.ceil(global_out_size / scale_val.shape[0])
220-
inferred_bs_in = math.ceil(global_in_size / scale_val.shape[1])
221-
if (inferred_bs_out != self.weight_block_size[0] or inferred_bs_in != self.weight_block_size[1]):
222-
effective_weight_block_size = (inferred_bs_out, inferred_bs_in)
223-
224220
output = shard_map(
225221
partial(
226222
xla_quantized_matmul_local,
227223
quantize_activation=quantize_activation,
228224
reduce_axis=input_axis,
229225
compute_dtype=self.compute_dtype,
230-
weight_block_size=effective_weight_block_size,
226+
weight_block_size=self.weight_block_size,
231227
activation_quant_dtype=self.activation_dtype,
232228
),
233229
mesh=self.mesh, in_specs=in_specs, out_specs=out_specs, check_vma=False,

python/sgl_jax/srt/layers/moe.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,25 @@ def _normalize_scale_for_gmm(
329329
num_experts, out_dim, in_dim = weight.shape
330330

331331
if scale.ndim == 4:
332+
if scale.shape[0] != num_experts or scale.shape[2] != 1 or scale.shape[3] != out_dim:
333+
raise ValueError(
334+
f"Unsupported {scale_name} shape {scale.shape} for weight shape {weight.shape}. "
335+
"Expected 4D GMM scale layout [E, k_blocks, 1, out_dim]."
336+
)
337+
if self.weight_block_size is None:
338+
if scale.shape[1] != 1:
339+
raise ValueError(
340+
f"Unsupported {scale_name} shape {scale.shape} for weight shape {weight.shape}. "
341+
"Per-channel 4D GMM scales must have k_blocks=1."
342+
)
343+
else:
344+
block_size_k = int(self.weight_block_size[1])
345+
expected_k_blocks = (in_dim + block_size_k - 1) // block_size_k
346+
if scale.shape[1] not in (1, expected_k_blocks):
347+
raise ValueError(
348+
f"Unsupported {scale_name} shape {scale.shape} for weight shape {weight.shape}. "
349+
f"Expected k_blocks dimension to be 1 or {expected_k_blocks}."
350+
)
332351
return scale
333352

334353
if scale.ndim == 2 and scale.shape == (num_experts, out_dim):

python/sgl_jax/srt/utils/quantization/quantization_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from jax.sharding import NamedSharding, PartitionSpec as P
1111

1212
from sgl_jax.srt.configs.model_config import ModelConfig
13-
from sgl_jax.srt.configs.quantization_config import DTYPE_MAP
13+
from sgl_jax.srt.configs.quantization_config import DTYPE_MAP, _normalize_weight_block_size
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -103,6 +103,7 @@ def apply_linear_quantization(
103103
if "weight_block_size" in rule
104104
else getattr(quant_config, "weight_block_size", None)
105105
)
106+
weight_block_size = _normalize_weight_block_size(weight_block_size)
106107

107108
# Convert string dtypes to jnp dtypes
108109
weight_dtype = DTYPE_MAP.get(weight_dtype_str)
@@ -120,7 +121,7 @@ def apply_linear_quantization(
120121
}
121122
)
122123

123-
ignored_layers = getattr(quant_config, "ignored_layers", None) or []
124+
ignored_layers = quant_config.ignored_layers or []
124125

125126
def _find_matching_rule(path: str):
126127
"""Find the first rule that matches the given module path."""
@@ -147,8 +148,11 @@ def _replace_linear_recursive(obj, path: str = "", visited: set | None = None):
147148
if isinstance(attr_value, LinearBase):
148149
# Check if this path matches any rule
149150
dot_path = child_path.replace("/", ".")
150-
if any(dot_path.endswith(ignored) or ignored in dot_path for ignored in ignored_layers):
151-
logger.info("Skipping %s - in ignored_layers", child_path)
151+
if any(
152+
dot_path == ignored or dot_path.endswith(f".{ignored}")
153+
for ignored in ignored_layers
154+
):
155+
logger.info("Skipping %s - in ignored_layers", dot_path)
152156
continue
153157

154158
rule = _find_matching_rule(child_path)

0 commit comments

Comments
 (0)