Skip to content

Commit 1cad65b

Browse files
committed
using tpu-inference quantized kernel
1 parent feed48b commit 1cad65b

6 files changed

Lines changed: 1786 additions & 37 deletions

File tree

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .blockwise_kernel import quantized_matmul_kernel as quantized_matmul
16+
17+
# Public API of this package: only the `quantized_matmul` alias imported above
# (it is `blockwise_kernel.quantized_matmul_kernel` under a shorter name).
__all__ = ["quantized_matmul"]
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Quantized matmul kernel with blockwise quantization support."""
3+
4+
import jax
5+
import jax.numpy as jnp
6+
from jax.experimental import pallas as pl
7+
from jax.experimental.pallas import tpu as pltpu
8+
9+
from . import util
10+
from .tuned_block_sizes import (
11+
TunedValue, get_device_vmem_limit, get_tuned_block_sizes)
12+
from .util import (get_kernel_name,
13+
next_multiple,
14+
unfold_args)
15+
16+
# Alias for `util.quantize_tensor`, exposed at this module's top level.
quantize_tensor = util.quantize_tensor
17+
# Hard-coded MXU column size. The kernel multiplies it by the tuned
# `n_lane_multiplier` to get the width of each output tile it computes per
# inner step (see the TODO in the kernel about querying this from JAX instead).
MXU_SIZE = 256
18+
19+
20+
@jax.jit(static_argnames=[
    "block_size",
    "x_q_dtype",
    "tuned_value",
])
def quantized_matmul_kernel(
    x: jax.Array,  # [bs, n_in]
    w_q: jax.Array,  # [n_out, n_in]
    w_scale: jax.Array,  # [n_in // block_size, 1, n_out]
    w_zp: jax.Array | None = None,  # [n_out]
    block_size: int | None = None,
    x_q_dtype: jnp.dtype | None = None,
    *,
    tuned_value: TunedValue | None = None,
) -> jax.Array:
    """Quantized matmul kernel with blockwise support.

    The inputs are padded up to multiples of the tuned block sizes, a Pallas
    kernel is launched over a (batch, out, in) grid, and the padding is
    sliced off the result before returning.

    Args:
        x: Input unquantized array. [batch_size, n_input_features]
        w_q: Weight quantized array. [n_output_features, n_input_features]
        w_scale: Weight quantization scale.
            [n_input_features // block_size, 1, n_output_features]
        w_zp: Weight zero point for asymmetric quantization. Not supported;
            must be None.
        block_size: Block size for subchannel quantization. Required. When it
            equals n_input_features (channelwise), the tuned `in_block_size`
            is used instead so that each grid step covers a single block.
        x_q_dtype: Quantization type of the input. If None, `x.dtype` is used.
        tuned_value: Kernel tuned values for optimal performance. If None,
            they are looked up with `get_tuned_block_sizes`.

    Returns:
        Quantized matmul result of shape [batch_size, n_output_features] and
        dtype `x.dtype`.

    Raises:
        ValueError: If `block_size` is None.
        NotImplementedError: If `w_zp` is not None.
    """

    if block_size is None:
        raise ValueError("Block size was not specified.")
    if w_zp is not None:
        raise NotImplementedError("zero_point is not supported.")

    if x_q_dtype is None:
        x_q_dtype = x.dtype
    quantize_activation = x_q_dtype != x.dtype

    orig_n_batch, orig_n_in = x.shape
    orig_n_out, *_ = w_q.shape

    if tuned_value is None:
        tuned_value = get_tuned_block_sizes(
            n_batch=orig_n_batch,
            n_out=orig_n_out,
            n_in=orig_n_in,
            x_q_dtype=jnp.dtype(x_q_dtype).name,
            w_q_dtype=jnp.dtype(w_q.dtype).name,
        )
    batch_block_size = tuned_value.batch_block_size
    out_block_size = tuned_value.out_block_size
    in_block_size = tuned_value.in_block_size
    n_lane_multiplier = tuned_value.n_lane_multiplier
    # The num_blocks should become 1 in case of channelwise.
    block_size = (tuned_value.in_block_size
                  if block_size == orig_n_in else block_size)

    # Pad the inputs to be multiple of block size.
    padded_n_batch = next_multiple(orig_n_batch, batch_block_size)
    if orig_n_batch < padded_n_batch:
        x = jnp.pad(x, ((0, padded_n_batch - orig_n_batch), (0, 0)))

    padded_n_out = next_multiple(orig_n_out, out_block_size)
    if orig_n_out < padded_n_out:
        w_q = jnp.pad(w_q, ((0, padded_n_out - orig_n_out), (0, 0)))
        # Pad only the trailing (n_out) axis of the scale. A bare
        # `(before, after)` pair would be broadcast by jnp.pad to every axis
        # and also grow the block and singleton axes of the 3-D scale.
        scale_pad_width = [(0, 0)] * (w_scale.ndim - 1)
        scale_pad_width.append((0, padded_n_out - orig_n_out))
        w_scale = jnp.pad(w_scale, scale_pad_width)
    padded_n_in = next_multiple(orig_n_in, in_block_size)
    if orig_n_in < padded_n_in:
        x = jnp.pad(x, ((0, 0), (0, padded_n_in - orig_n_in)))
        w_q = jnp.pad(w_q, ((0, 0), (0, padded_n_in - orig_n_in)))
        # NOTE(review): w_scale is not padded along its first (n_in block)
        # axis here, so scale blocks for the padded input region are not
        # backed by real data -- confirm whether that needs handling.

    if w_scale.dtype != jnp.float32:
        w_scale = w_scale.astype(jnp.float32)

    # Grid extents, i.e. the number of blocks along each dimension.
    n_batch = padded_n_batch // batch_block_size
    n_out = padded_n_out // out_block_size
    n_in = padded_n_in // in_block_size

    save_acc = n_in > 1
    # Remove redundant input quantization logic by caching quantized input. For
    # best performance, only enable this behavior when single input block is
    # used per batch.
    save_x_q = quantize_activation and n_in == 1 and n_out > 1

    # TODO(amandaliang): Make this configurable.
    acc_dtype = jnp.bfloat16
    if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
        acc_dtype = jnp.int32

    vmem_limit_bytes = util.get_vmem_limit(
        n_batch=n_batch,
        n_out=n_out,
        n_in=n_in,
        batch_block_size=batch_block_size,
        out_block_size=out_block_size,
        in_block_size=in_block_size,
        x_dtype=x.dtype,
        x_q_dtype=x_q_dtype,
        w_q_dtype=w_q.dtype,
        scale_dtype=jnp.float32,
        out_dtype=x.dtype,
        acc_dtype=acc_dtype,
        save_acc=save_acc,
        save_x_q=save_x_q,
        upper_limit_bytes=get_device_vmem_limit(),
    )

    # Number of quantization blocks inside one in_block_size slab.
    steps_k = in_block_size // block_size
    # n_lane_multiplier > 1 could improve perf by reducing loop overhead and
    # increasing instruction-level parallelism, allowing the compiler to
    # overlap output fusion and packing overhead with MXU computation
    # TODO(amandaliang): use pltpu.get_tpu_info().mxu_column_size when JAX
    # version is newer
    compute_tile_n = MXU_SIZE * n_lane_multiplier
    steps_n = out_block_size // compute_tile_n

    def kernel(lhs_ref, rhs_ref, w_scales_ref, out_ref, acc_scratch):
        pid_k = pl.program_id(2)
        is_first_step = pid_k == 0
        # Compare against the padded grid extent along the reduction axis.
        # Using the unpadded `orig_n_in` here would mark the wrong step (or no
        # step at all) as last whenever the input dimension was padded, and
        # the final partial sum would never reach `out_ref`.
        is_last_step = pid_k == (n_in - 1)

        def accum(is_first_step, is_last_step):
            # One partial result per output tile of width compute_tile_n.
            accumulators = [None] * steps_n

            # Loop-invariant: integer activations accumulate in int32,
            # everything else in float32.
            if jnp.issubdtype(x_q_dtype, jnp.integer):
                preferred_element_type = jnp.int32
            else:
                preferred_element_type = jnp.float32

            for i in range(steps_k):
                k_start, k_end = i * block_size, (i + 1) * block_size
                lhs_sub = lhs_ref[:, k_start:k_end].astype(jnp.float32)
                lhs_q, lhs_scale = util.quantize_block(lhs_sub, 1, x_q_dtype)
                lhs_scale = lhs_scale.astype(acc_dtype)

                rhs_q_full = rhs_ref[:, k_start:k_end]
                rhs_scale_full = w_scales_ref[i, :, :].astype(acc_dtype)

                for j in range(steps_n):
                    n_start = j * compute_tile_n
                    n_end = (j + 1) * compute_tile_n

                    rhs_q_slice = rhs_q_full[n_start:n_end, :]
                    rhs_scale_slice = rhs_scale_full[:, n_start:n_end]
                    # Contract axis 1 of both operands (the input-feature
                    # axis); no batch dimensions.
                    dot_res = jax.lax.dot_general(
                        lhs_q,
                        rhs_q_slice,
                        (((1, ), (1, )), ((), ())),
                        preferred_element_type=preferred_element_type,
                    )
                    res = dot_res.astype(acc_dtype)
                    res = res * lhs_scale
                    res = res * rhs_scale_slice
                    if i == 0:
                        accumulators[j] = res
                    else:
                        accumulators[j] += res

            acc_block = jnp.concatenate(accumulators, axis=1)

            if not is_first_step:
                acc_block += acc_scratch[...]

            if is_last_step:
                out_ref[...] = acc_block.astype(out_ref.dtype)
            else:
                acc_scratch[...] = acc_block

        unfold_args((is_first_step, is_last_step), (), accum)

    kernel = pl.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(
                    (batch_block_size, in_block_size),
                    lambda b, o, i: (b, i),
                    memory_space=pltpu.VMEM,
                ),  # x
                pl.BlockSpec(
                    (out_block_size, in_block_size),
                    lambda b, o, i: (o, i),
                    memory_space=pltpu.VMEM,
                ),  # w_q
                pl.BlockSpec(
                    (steps_k, 1, out_block_size),
                    lambda _, o, i: (i, 0, o),
                    memory_space=pltpu.VMEM,
                ),  # w_scale
            ],
            out_specs=pl.BlockSpec((batch_block_size, out_block_size),
                                   lambda b, o, i: (b, o)),
            # NOTE(review): the scratch dtype is hard-coded to bfloat16 while
            # acc_dtype above may be int32 -- confirm this is intended.
            scratch_shapes=[
                pltpu.VMEM((batch_block_size, out_block_size), jnp.bfloat16)
            ],
            grid=(n_batch, n_out, n_in),
        ),
        out_shape=jax.ShapeDtypeStruct((padded_n_batch, padded_n_out),
                                       x.dtype),
        compiler_params=pltpu.CompilerParams(
            dimension_semantics=("parallel", "parallel", "arbitrary"),
            vmem_limit_bytes=vmem_limit_bytes,
        ),
    )

    util.validate_inputs(
        x=x,
        w_q=w_q,
        w_scale=w_scale,
        x_abs_max=None,
        x_q_dtype=x_q_dtype,
        batch_block_size=batch_block_size,
        out_block_size=out_block_size,
        in_block_size=in_block_size,
    )

    # The named_scope is used for autotune.
    kernel_name = get_kernel_name(tuned_value)
    with jax.named_scope(kernel_name):
        out = kernel(x, w_q, w_scale)

    # Drop the rows/columns that were added by padding.
    return out[:orig_n_batch, :orig_n_out]

0 commit comments

Comments
 (0)