Skip to content

Commit 5de6d94

Browse files
authored
Gather qmm batched kernel and refactoring of quantized (#2078)
1 parent 99eefd2 commit 5de6d94

File tree

15 files changed

+1482
-452
lines changed

15 files changed

+1482
-452
lines changed

benchmarks/python/gather_mm_bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright © 2023-2024 Apple Inc.
1+
# Copyright © 2025 Apple Inc.
22

33
import mlx.core as mx
44
from time_utils import time_fn
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright © 2025 Apple Inc.
2+
3+
import mlx.core as mx
4+
from time_utils import time_fn
5+
6+
# Problem sizes used by time_gather_qmm below.
N = 1024  # leading dimension of the input x and of the index matrix
D = 1024  # last dimension of x (and of the first weight stack)
M = 1024  # output rows of each first-layer weight matrix
E = 32  # number of stacked weight matrices that indices select from
I = 4  # number of indices drawn per input row
11+
12+
13+
def gather_sort(x, indices):
    """Sort the flattened indices and gather the rows of ``x`` to match.

    ``indices`` must be two-dimensional (its shape is unpacked into two
    values). Returns a 3-tuple:

    * ``x`` with all but its last two axes flattened, then indexed by the
      source row of every sorted index,
    * the flattened indices in sorted order,
    * the inverse permutation, usable to restore the original order.
    """
    _, per_row = indices.shape
    flat_indices = indices.flatten()
    perm = mx.argsort(flat_indices)
    inverse_perm = mx.argsort(perm)
    # Integer division by the row width turns a position in the flattened
    # index array back into the row of `indices` (and of x) it came from.
    gathered = x.flatten(0, -3)[perm // per_row]
    return gathered, flat_indices[perm], inverse_perm
19+
20+
21+
def scatter_unsort(x, inv_order, shape=None):
    """Reorder ``x`` along its first axis with ``inv_order``.

    When ``shape`` is given, the first axis of the reordered result is
    additionally unflattened into ``shape``; otherwise the reordered array
    is returned as is.
    """
    restored = x[inv_order]
    if shape is None:
        return restored
    return mx.unflatten(restored, 0, shape)
26+
27+
28+
def gather_mm_simulate(x, w, indices):
    """Emulate a gathered quantized matmul with one matmul per sorted row.

    The rows of ``x`` are sorted by index via ``gather_sort``; then, for two
    passes using the same ``w``, every row is multiplied by the weight entry
    its index selects and the results are concatenated. The original
    ordering and index shape are restored at the end.

    ``w`` is indexed as ``w[0]``, ``w[1]``, ``w[2]`` -- presumably the
    (weights, scales, biases) triple produced by ``mx.quantize``; confirm
    against the caller.
    """
    x, idx, inv_order = gather_sort(x, indices)
    for _ in range(2):
        pieces = []
        for row, expert in enumerate(idx.tolist()):
            pieces.append(
                mx.quantized_matmul(
                    x[row], w[0][expert], w[1][expert], w[2][expert], transpose=True
                )
            )
        # Re-insert a singleton axis so the next pass can index rows again.
        x = mx.concatenate(pieces, axis=0)[:, None]
    return scatter_unsort(x, inv_order, indices.shape)
41+
42+
43+
def time_gather_qmm():
    """Benchmark two chained ``mx.gather_qmm`` calls against plain matmuls.

    Three gathered variants are timed: unsorted indices, pre-sorted indices
    (still run with ``sort=False``), and unsorted indices that are sorted
    and unsorted around the two matmuls. Finally two ordinary
    ``mx.quantized_matmul`` calls on an ``(N * I, D)`` input are timed as a
    point of comparison.
    """
    norm = 1024**0.5

    # Gathered case: a stack of E quantized weight matrices per layer.
    x = mx.random.normal((N, 1, 1, D)) / norm
    w1 = mx.quantize(mx.random.normal((E, M, D)) / norm)
    w2 = mx.quantize(mx.random.normal((E, D, M)) / norm)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    # Materialize the inputs up front so their creation is not timed.
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx, inv_order = indices, None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        for weights in (w1, w2):
            x = mx.gather_qmm(
                x, *weights, transpose=True, rhs_indices=idx, sorted_indices=sort
            )
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    for idx_arg, sort_flag in (
        (indices, False),
        (sorted_indices, False),
        (indices, True),
    ):
        time_fn(gather_mm, x, w1, w2, idx_arg, sort_flag)

    # Dense comparison: one quantized weight matrix per layer, N * I rows.
    x = mx.random.normal((N * I, D)) / norm
    w1 = mx.quantize(mx.random.normal((M, D)) / norm)
    w2 = mx.quantize(mx.random.normal((D, M)) / norm)
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        for weights in (w1, w2):
            x = mx.quantized_matmul(x, *weights, transpose=True)
        return x

    time_fn(equivalent_matmul, x, w1, w2)
81+
82+
83+
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    time_gather_qmm()

mlx/backend/metal/jit_kernels.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
752752
return d.get_kernel(kernel_name, lib);
753753
}
754754

755+
// Returns the pipeline state for a "gather_qmm_rhs" kernel instantiation.
//
// The library is looked up under `kernel_name`. The source builder handed to
// get_library concatenates the utils, gemm and quantized Metal sources with a
// template definition of "gather_qmm_rhs" specialized on the dtype of `x`,
// `group_size`, `bits`, the bm/bn/bk and wm/wn parameters, and `transpose`.
// The kernel itself is then requested with `hash_name` and `func_consts`.
MTL::ComputePipelineState* get_gather_qmm_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& x,
    int group_size,
    int bits,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool transpose) {
  // Library name and kernel name are one and the same here.
  auto build_source = [&]() {
    std::string source;
    concatenate(
        source,
        metal::utils(),
        metal::gemm(),
        metal::quantized(),
        get_template_definition(
            kernel_name,
            "gather_qmm_rhs",
            get_type_string(x.dtype()),
            group_size,
            bits,
            bm,
            bn,
            bk,
            wm,
            wn,
            transpose));
    return source;
  };
  auto lib = d.get_library(kernel_name, build_source);
  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
793+
755794
} // namespace mlx::core

mlx/backend/metal/kernels.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
224224
const std::string& kernel_name,
225225
const std::string& template_def);
226226

227+
// Get the pipeline state for a gather-qmm kernel. The definition in
// jit_kernels.cpp uses kernel_name as the library name, instantiates the
// "gather_qmm_rhs" template with the dtype of x, group_size, bits,
// bm/bn/bk, wm/wn and transpose, and requests the kernel with hash_name and
// func_consts.
MTL::ComputePipelineState* get_gather_qmm_kernel(
228+
metal::Device& d,
229+
const std::string& kernel_name,
230+
const std::string& hash_name,
231+
const metal::MTLFCList& func_consts,
232+
const array& x,
233+
int group_size,
234+
int bits,
235+
int bm,
236+
int bn,
237+
int bk,
238+
int wm,
239+
int wn,
240+
bool transpose);
241+
227242
// Create a GPU kernel template definition for JIT compilation
228243
template <typename... Args>
229244
std::string

0 commit comments

Comments
 (0)