Skip to content

Commit 8536859

Browse files
committed
add test cases
1 parent a9f778e commit 8536859

16 files changed

Lines changed: 825 additions & 165 deletions

File tree

python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/blockwise_kernel.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Adapted from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/quantized_matmul/blockwise_kernel.py
12
# SPDX-License-Identifier: Apache-2.0
23
"""Quantized matmul kernel with blockwise quantization support."""
34

@@ -83,7 +84,7 @@ def quantized_matmul_kernel(
8384
padded_n_out = next_multiple(orig_n_out, out_block_size)
8485
if orig_n_out < padded_n_out:
8586
w_q = jnp.pad(w_q, ((0, padded_n_out - orig_n_out), (0, 0)))
86-
w_scale = jnp.pad(w_scale, (0, padded_n_out - orig_n_out))
87+
w_scale = jnp.pad(w_scale, ((0, 0), (0, 0), (0, padded_n_out - orig_n_out)))
8788
padded_n_in = next_multiple(orig_n_in, in_block_size)
8889
if orig_n_in < padded_n_in:
8990
x = jnp.pad(x, ((0, 0), (0, padded_n_in - orig_n_in)))
@@ -135,16 +136,20 @@ def quantized_matmul_kernel(
135136
def kernel(lhs_ref, rhs_ref, w_scales_ref, out_ref, acc_scratch):
136137
pid_k = pl.program_id(2)
137138
is_first_step = pid_k == 0
138-
is_last_step = pid_k == (orig_n_in // in_block_size - 1)
139+
is_last_step = pid_k == (n_in - 1)
139140

140141
def accum(is_first_step, is_last_step):
141142
accumulators = [None] * steps_n
142143

143144
for i in range(steps_k):
144145
k_start, k_end = i * block_size, (i + 1) * block_size
145-
lhs_sub = lhs_ref[:, k_start:k_end].astype(jnp.float32)
146-
lhs_q, lhs_scale = util.quantize_block(lhs_sub, 1, x_q_dtype)
147-
lhs_scale = lhs_scale.astype(acc_dtype)
146+
if quantize_activation:
147+
lhs_sub = lhs_ref[:, k_start:k_end].astype(jnp.float32)
148+
lhs_q, lhs_scale = util.quantize_block(lhs_sub, 1, x_q_dtype)
149+
lhs_scale = lhs_scale.astype(acc_dtype)
150+
else:
151+
lhs_q = lhs_ref[:, k_start:k_end]
152+
lhs_scale = None
148153

149154
rhs_q_full = rhs_ref[:, k_start:k_end]
150155
rhs_scale_full = w_scales_ref[i, :, :].astype(acc_dtype)
@@ -166,7 +171,8 @@ def accum(is_first_step, is_last_step):
166171
preferred_element_type=preferred_element_type,
167172
)
168173
res = dot_res.astype(acc_dtype)
169-
res = res * lhs_scale
174+
if lhs_scale is not None:
175+
res = res * lhs_scale
170176
res = res * rhs_scale_slice
171177
if i == 0:
172178
accumulators[j] = res

python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/kernel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Adapted from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/quantized_matmul/kernel.py
12
# SPDX-License-Identifier: Apache-2.0
23
"""Quantized matmul kernel."""
34

@@ -184,7 +185,7 @@ def quantized_matmul_kernel(
184185
padded_n_out = next_multiple(orig_n_out, out_block_size)
185186
if orig_n_out < padded_n_out:
186187
w_q = jnp.pad(w_q, ((0, padded_n_out - orig_n_out), (0, 0)))
187-
w_scale = jnp.pad(w_scale, (0, padded_n_out - orig_n_out))
188+
w_scale = jnp.pad(w_scale, ((0, 0), (0, padded_n_out - orig_n_out)))
188189
padded_n_in = next_multiple(orig_n_in, in_block_size)
189190
if orig_n_in < padded_n_in:
190191
x = jnp.pad(x, ((0, 0), (0, padded_n_in - orig_n_in)))

python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/tuned_block_sizes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Adapted from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py
12
# SPDX-License-Identifier: Apache-2.0
23
"""Tuned block sizes for quantized matmul kernel."""
34

python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Adapted from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/quantized_matmul/util.py
12
# SPDX-License-Identifier: Apache-2.0
23
"""Utility functions for quantized matmul kernel."""
34
from typing import Any, Callable

python/sgl_jax/srt/kernels/quantized_matmul/kernel.py

Lines changed: 100 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_TRIED_LOADING_BLOCKWISE_3RD_KERNEL = False
1616
_BLOCKWISE_3RD_TUNED_VALUE_CLS = None
1717
_BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES = None
18+
_BLOCKWISE_3RD_TUNED_BLOCK_SIZES = None
1819
_TRIED_LOADING_BLOCKWISE_3RD_TUNING = False
1920

2021

@@ -39,22 +40,33 @@ def _get_blockwise_3rd_tuning_api():
3940
"""Lazily load third-party tuned-size helpers for blockwise kernel."""
4041
global _BLOCKWISE_3RD_TUNED_VALUE_CLS
4142
global _BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES
43+
global _BLOCKWISE_3RD_TUNED_BLOCK_SIZES
4244
global _TRIED_LOADING_BLOCKWISE_3RD_TUNING
4345

4446
if _TRIED_LOADING_BLOCKWISE_3RD_TUNING:
45-
return _BLOCKWISE_3RD_TUNED_VALUE_CLS, _BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES
47+
return (
48+
_BLOCKWISE_3RD_TUNED_VALUE_CLS,
49+
_BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES,
50+
_BLOCKWISE_3RD_TUNED_BLOCK_SIZES,
51+
)
4652
_TRIED_LOADING_BLOCKWISE_3RD_TUNING = True
4753

4854
try:
4955
package = __package__ or "sgl_jax.srt.kernels.quantized_matmul"
5056
module = importlib.import_module(f"{package}.3rd_quantized_matmul.tuned_block_sizes")
5157
_BLOCKWISE_3RD_TUNED_VALUE_CLS = getattr(module, "TunedValue", None)
5258
_BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES = getattr(module, "get_tuned_block_sizes", None)
59+
_BLOCKWISE_3RD_TUNED_BLOCK_SIZES = getattr(module, "TUNED_BLOCK_SIZES", None)
5360
except Exception:
5461
_BLOCKWISE_3RD_TUNED_VALUE_CLS = None
5562
_BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES = None
63+
_BLOCKWISE_3RD_TUNED_BLOCK_SIZES = None
5664

57-
return _BLOCKWISE_3RD_TUNED_VALUE_CLS, _BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES
65+
return (
66+
_BLOCKWISE_3RD_TUNED_VALUE_CLS,
67+
_BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES,
68+
_BLOCKWISE_3RD_TUNED_BLOCK_SIZES,
69+
)
5870

5971

6072
def _next_multiple(x: int, m: int) -> int:
@@ -63,6 +75,72 @@ def _next_multiple(x: int, m: int) -> int:
6375
return ((x + m - 1) // m) * m
6476

6577

78+
def _floor_multiple(x: int, m: int) -> int:
79+
if m <= 0:
80+
return x
81+
return max(m, (x // m) * m)
82+
83+
84+
def _nearest_power_of_two_multiple(x: int, base: int, upper_bound: int) -> int:
85+
if base <= 0:
86+
return x
87+
88+
x = max(base, x)
89+
units = max(1, x // base)
90+
lower_units = 1 << (units.bit_length() - 1)
91+
upper_units = lower_units if lower_units == units else lower_units << 1
92+
93+
def _candidate(units_value: int) -> int:
94+
return units_value * base
95+
96+
lower = _candidate(lower_units)
97+
upper = _candidate(upper_units)
98+
candidates = [value for value in (lower, upper) if value <= upper_bound]
99+
if not candidates:
100+
candidates = [lower]
101+
102+
return min(candidates, key=lambda value: (abs(value - x), -value))
103+
104+
105+
def _iter_blockwise_tuned_candidates(
106+
tuned_block_sizes: dict | None,
107+
n_batch: int,
108+
n_out: int,
109+
n_in: int,
110+
x_q_dtype: jnp.dtype,
111+
w_q_dtype: jnp.dtype,
112+
):
113+
if not tuned_block_sizes:
114+
return []
115+
116+
x_q_dtype_name = jnp.dtype(x_q_dtype).name
117+
w_q_dtype_name = jnp.dtype(w_q_dtype).name
118+
compatible_x_dtype_names = [x_q_dtype_name]
119+
if jnp.issubdtype(w_q_dtype, jnp.integer) and x_q_dtype_name != "int8":
120+
compatible_x_dtype_names.append("int8")
121+
122+
candidates = []
123+
for key, value in tuned_block_sizes.items():
124+
if key.w_q_dtype != w_q_dtype_name:
125+
continue
126+
if key.x_q_dtype not in compatible_x_dtype_names:
127+
continue
128+
129+
score = (
130+
compatible_x_dtype_names.index(key.x_q_dtype),
131+
key.n_in != n_in,
132+
abs(key.n_in - n_in),
133+
key.n_batch != n_batch,
134+
abs(key.n_batch - n_batch),
135+
key.n_out != n_out,
136+
abs(key.n_out - n_out),
137+
)
138+
candidates.append((score, value))
139+
140+
candidates.sort(key=lambda item: item[0])
141+
return [value for _, value in candidates]
142+
143+
66144
def _get_safe_blockwise_tuned_value(
67145
n_batch: int,
68146
n_out: int,
@@ -72,12 +150,22 @@ def _get_safe_blockwise_tuned_value(
72150
block_size_in: int,
73151
):
74152
"""Build a safe tuned value for third-party blockwise kernel on TPU."""
75-
tuned_value_cls, get_tuned_block_sizes = _get_blockwise_3rd_tuning_api()
153+
tuned_value_cls, get_tuned_block_sizes, tuned_block_sizes = _get_blockwise_3rd_tuning_api()
76154
if tuned_value_cls is None:
77155
return None
78156

79157
tuned = None
80-
if get_tuned_block_sizes is not None:
158+
compatible_candidates = _iter_blockwise_tuned_candidates(
159+
tuned_block_sizes=tuned_block_sizes,
160+
n_batch=n_batch,
161+
n_out=n_out,
162+
n_in=n_in,
163+
x_q_dtype=x_q_dtype,
164+
w_q_dtype=w_q_dtype,
165+
)
166+
if compatible_candidates:
167+
tuned = compatible_candidates[0]
168+
elif get_tuned_block_sizes is not None:
81169
try:
82170
tuned = get_tuned_block_sizes(
83171
n_batch=n_batch,
@@ -94,10 +182,17 @@ def _get_safe_blockwise_tuned_value(
94182
n_lane_multiplier = max(1, int(tuned.n_lane_multiplier))
95183
compute_tile_n = 256 * n_lane_multiplier
96184

97-
batch_block_size = max(1, int(tuned.batch_block_size))
185+
batch_block_size = max(1, min(int(tuned.batch_block_size), int(n_batch)))
98186
out_block_size = _next_multiple(max(int(tuned.out_block_size), compute_tile_n), compute_tile_n)
187+
out_block_size = min(out_block_size, _floor_multiple(int(n_out), compute_tile_n))
188+
out_block_size = _nearest_power_of_two_multiple(
189+
out_block_size,
190+
compute_tile_n,
191+
_floor_multiple(int(n_out), compute_tile_n),
192+
)
99193
in_block_size = max(int(tuned.in_block_size), int(block_size_in))
100194
in_block_size = _next_multiple(in_block_size, int(block_size_in))
195+
in_block_size = min(in_block_size, _floor_multiple(int(n_in), int(block_size_in)))
101196

102197
return tuned_value_cls(batch_block_size, out_block_size, in_block_size, n_lane_multiplier)
103198

python/sgl_jax/srt/layers/linear.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import jax
99
import jax.numpy as jnp
1010
from flax import nnx
11+
from jax import shard_map
1112
from jax.sharding import NamedSharding, PartitionSpec as P
12-
from jax.experimental.shard_map import shard_map
1313

1414
from sgl_jax.srt.kernels.quantized_matmul.kernel import xla_quantized_matmul_local
1515
from sgl_jax.srt.utils.quantization.quantization_utils import quantize_tensor
@@ -61,12 +61,35 @@ def __init__(
6161

6262
def __call__(self, x: jax.Array) -> jax.Array | tuple[jax.Array, jax.Array]:
6363
"""Forward pass."""
64-
out = jnp.dot(x, self.weight.value)
64+
x_2d = x.reshape(-1, x.shape[-1]) if x.ndim > 2 else x
65+
66+
if self.mesh is not None and self.kernel_axes is not None:
67+
input_axis, output_axis = self.kernel_axes[0], self.kernel_axes[1]
68+
69+
def _sharded_dot(lhs: jax.Array, rhs: jax.Array) -> jax.Array:
70+
y = jnp.dot(lhs, rhs)
71+
if input_axis is not None:
72+
y = jax.lax.psum(y, input_axis)
73+
return y
74+
75+
out = shard_map(
76+
_sharded_dot,
77+
mesh=self.mesh,
78+
in_specs=(P(None, input_axis), P(input_axis, output_axis)),
79+
out_specs=P(None, output_axis),
80+
check_vma=False,
81+
)(x_2d, self.weight.value)
82+
else:
83+
out = jnp.dot(x_2d, self.weight.value)
84+
85+
if x.ndim > 2:
86+
out = out.reshape(x.shape[:-1] + (out.shape[-1],))
87+
88+
if self.skip_bias_add:
89+
return out, (self.bias.value if self.bias is not None else None)
6590
if self.bias is not None:
66-
if self.skip_bias_add:
67-
return out, self.bias.value
6891
return out + self.bias.value
69-
return out
92+
return out, None
7093

7194

7295
class QuantizedLinear(nnx.Module):
@@ -168,7 +191,8 @@ def from_linear(
168191
return cls(
169192
weight_q=weight_q, weight_scale=weight_scale, bias=bias,
170193
activation_dtype=activation_dtype, mesh=linear.mesh,
171-
kernel_axes=linear.kernel_axes, skip_bias_add=linear.skip_bias_add,
194+
kernel_axes=linear.kernel_axes,
195+
skip_bias_add=linear.skip_bias_add or linear.bias is None,
172196
params_dtype=linear.params_dtype, weight_block_size=effective_weight_block_size,
173197
scope_name=f"quantized_{linear.name}",
174198
)
@@ -212,7 +236,8 @@ def __call__(self, x: jax.Array) -> jax.Array | tuple[jax.Array, jax.Array]:
212236
if x.ndim > 2:
213237
output = output.reshape(x.shape[:-1] + (output.shape[-1],))
214238

239+
if self.skip_bias_add:
240+
return output, (self.bias.value if self.bias is not None else None)
215241
if self.bias is not None:
216-
if self.skip_bias_add: return output, self.bias.value
217242
return output + self.bias.value
218243
return output

0 commit comments

Comments
 (0)