Commit 26be608

Add split_k qvm for long context (#1564)
* Add splitk qvm
* configurable splitk
* tuning
* remove extra instantiation
* remove refactor
* separate test
* cpu tolerance
1 parent 248431e commit 26be608
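
For context, a split-k matvec partitions the shared inner dimension across extra threadgroups along the grid's z axis, writes one partial product per partition into an intermediate buffer, and sums the partials afterwards. A minimal NumPy sketch of the decomposition this commit implements on the GPU (the function name and shapes here are illustrative, not part of the commit):

    import numpy as np

    def split_k_matvec(x, w, split_k=8):
        # Partition the shared dimension D into `split_k` blocks, compute one
        # partial product per block, then sum the partials. On device the
        # partials land in an intermediate array and a strided sum reduction
        # produces the final output.
        D = x.shape[-1]
        split_D = (D + split_k - 1) // split_k  # ceil-divide; last block may be short
        partials = [
            x[..., i * split_D : min((i + 1) * split_D, D)]
            @ w[i * split_D : min((i + 1) * split_D, D), :]
            for i in range(split_k)
        ]
        return np.sum(partials, axis=0)

For example, `split_k_matvec(np.random.randn(1, 16384), np.random.randn(16384, 128))` matches `x @ w` up to floating-point accumulation order.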

File tree

4 files changed: +220 -4 lines changed

mlx/backend/metal/kernels/quantized.h

Lines changed: 57 additions & 2 deletions

@@ -650,8 +650,8 @@ METAL_FUNC void qvm_impl(
     const device T* biases,
     const device T* x,
     device T* y,
-    const constant int& in_vec_size,
-    const constant int& out_vec_size,
+    const int in_vec_size,
+    const int out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {

@@ -1298,6 +1298,61 @@ template <typename T, const int group_size, const int bits, bool batched>
       simd_lid);
 }
 
+template <typename T, const int group_size, const int bits, int split_k = 32>
+[[kernel]] void qvm_split_k(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& in_vec_size [[buffer(5)]],
+    const constant int& out_vec_size [[buffer(6)]],
+    const constant int& x_batch_ndims [[buffer(7)]],
+    const constant int* x_shape [[buffer(8)]],
+    const constant size_t* x_strides [[buffer(9)]],
+    const constant int& w_batch_ndims [[buffer(10)]],
+    const constant int* w_shape [[buffer(11)]],
+    const constant size_t* w_strides [[buffer(12)]],
+    const constant size_t* s_strides [[buffer(13)]],
+    const constant size_t* b_strides [[buffer(14)]],
+    const constant int& final_block_size [[buffer(15)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      y,
+      out_vec_size,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+
+  // When (in_vec_size % split_k != 0) the final block needs to be smaller
+  int in_vec_size_adj =
+      tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
+
+  qvm_impl<T, group_size, bits>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      in_vec_size_adj,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
 template <
     typename T,
     const int group_size,
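
The only subtlety in the kernel above is the tail block: when in_vec_size is not divisible by split_k, the last z-block must process fewer elements, so the host passes final_block_size and the kernel substitutes it for the final tid.z. A quick sketch of the arithmetic, with assumed example values (the host-side computation appears in mlx/backend/metal/quantized.cpp below):

    # Assumed values for illustration only.
    D, split_k = 10000, 32
    split_D = (D + split_k - 1) // split_k          # 313, size of each split block
    final_block_size = D - (split_k - 1) * split_D  # 297, the smaller tail block
    assert (split_k - 1) * split_D + final_block_size == D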

mlx/backend/metal/kernels/quantized.metal

Lines changed: 15 additions & 1 deletion

@@ -51,6 +51,15 @@
       D, \
       batched)
 
+#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
+      name, \
+      type, \
+      group_size, \
+      bits, \
+      split_k)
+
 #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
   instantiate_quantized_batched(name, type, group_size, bits, 1) \
   instantiate_quantized_batched(name, type, group_size, bits, 0)

@@ -84,11 +93,16 @@
   instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
   instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
 
+#define instantiate_quantized_all_splitk(type, group_size, bits) \
+  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
+  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
+
 #define instantiate_quantized_funcs(type, group_size, bits) \
   instantiate_quantized_all_single(type, group_size, bits) \
   instantiate_quantized_all_batched(type, group_size, bits) \
   instantiate_quantized_all_aligned(type, group_size, bits) \
-  instantiate_quantized_all_quad(type, group_size, bits)
+  instantiate_quantized_all_quad(type, group_size, bits) \
+  instantiate_quantized_all_splitk(type, group_size, bits)
 
 #define instantiate_quantized_types(group_size, bits) \
   instantiate_quantized_funcs(float, group_size, bits) \
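
Each instantiate_quantized_split_k expansion registers one kernel specialization per (type, group_size, bits, split_k) under a predictable name, which is what the host-side kname string in quantized.cpp below reconstructs. A small sketch of the names produced for one example combination (type=float, group_size=64, bits=4):

    # Names derived from the macro's token pasting; combination chosen for illustration.
    for split_k in (8, 32):
        print(f"qvm_split_k_float_gs_64_b_4_spk_{split_k}")
    # qvm_split_k_float_gs_64_b_4_spk_8
    # qvm_split_k_float_gs_64_b_4_spk_32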

mlx/backend/metal/quantized.cpp

Lines changed: 123 additions & 1 deletion

@@ -6,6 +6,7 @@
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/reduce.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"

@@ -148,6 +149,125 @@ void launch_qmm(
   d.add_temporaries(std::move(copies), s.index);
 }
 
+void qvm_split_k(
+    const std::vector<array>& inputs,
+    array& out,
+    int group_size,
+    int bits,
+    int D,
+    int O,
+    int B,
+    int N,
+    const Stream& s) {
+  int split_k = D > 8192 ? 32 : 8;
+  int split_D = (D + split_k - 1) / split_k;
+  N *= split_k;
+
+  int bo = 64;
+  int bd = 32;
+  MTL::Size group_dims = MTL::Size(bd, 2, 1);
+  MTL::Size grid_dims = MTL::Size(O / bo, B, N);
+
+  auto& x_pre = inputs[0];
+  auto& w_pre = inputs[1];
+  auto& scales_pre = inputs[2];
+  auto& biases_pre = inputs[3];
+
+  // Ensure that the last two dims are row contiguous.
+  // TODO: Check if we really need this for x as well...
+  std::vector<array> copies;
+  auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
+    auto stride_0 = arr.strides()[arr.ndim() - 2];
+    auto stride_1 = arr.strides()[arr.ndim() - 1];
+    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      copies.push_back(arr_copy);
+      return arr_copy;
+    }
+  };
+  auto x = ensure_row_contiguous_last_dims(x_pre);
+  auto w = ensure_row_contiguous_last_dims(w_pre);
+  auto scales = ensure_row_contiguous_last_dims(scales_pre);
+  auto biases = ensure_row_contiguous_last_dims(biases_pre);
+
+  int x_batch_ndims = x.ndim() - 2;
+  auto x_shape = x.shape();
+  auto x_strides = x.strides();
+  int w_batch_ndims = w.ndim() - 2;
+  auto w_shape = w.shape();
+  auto w_strides = w.strides();
+  auto s_strides = scales.strides();
+  auto b_strides = biases.strides();
+
+  // Add split_k dim with reshapes
+  x_shape.insert(x_shape.end() - 2, split_k);
+  x_shape.back() /= split_k;
+  x_strides.insert(x_strides.end() - 2, split_D);
+  x_strides[x.ndim() - 1] = split_D;
+  x_batch_ndims += 1;
+
+  w_shape.insert(w_shape.end() - 2, split_k);
+  w_shape[w.ndim() - 1] /= split_k;
+  w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
+  w_batch_ndims += 1;
+  s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
+  b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
+
+  int final_block_size = D - (split_k - 1) * split_D;
+
+  auto& d = metal::device(s.device);
+
+  auto temp_shape = out.shape();
+  temp_shape.insert(temp_shape.end() - 2, split_k);
+  array intermediate(temp_shape, x.dtype(), nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  d.add_temporary(intermediate, s.index);
+
+  std::ostringstream kname;
+  auto type_string = get_type_string(x.dtype());
+  kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_"
+        << bits << "_spk_" << split_k;
+  auto template_def = get_template_definition(
+      kname.str(), "qvm_split_k", type_string, group_size, bits, split_k);
+
+  // Encode and dispatch kernel
+  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder->setComputePipelineState(kernel);
+
+  compute_encoder.set_input_array(w, 0);
+  compute_encoder.set_input_array(scales, 1);
+  compute_encoder.set_input_array(biases, 2);
+  compute_encoder.set_input_array(x, 3);
+  compute_encoder.set_output_array(intermediate, 4);
+  compute_encoder->setBytes(&split_D, sizeof(int), 5);
+  compute_encoder->setBytes(&O, sizeof(int), 6);
+
+  compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
+  set_vector_bytes(compute_encoder, x_shape, 8);
+  set_vector_bytes(compute_encoder, x_strides, 9);
+  compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
+  set_vector_bytes(compute_encoder, w_shape, 11);
+  set_vector_bytes(compute_encoder, w_strides, 12);
+  set_vector_bytes(compute_encoder, s_strides, 13);
+  set_vector_bytes(compute_encoder, b_strides, 14);
+  compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
+
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  d.add_temporaries(std::move(copies), s.index);
+
+  int axis = intermediate.ndim() - 3;
+  ReductionPlan plan(
+      ReductionOpType::ContiguousStridedReduce,
+      {intermediate.shape(axis)},
+      {intermediate.strides(axis)});
+  strided_reduce_general_dispatch(
+      intermediate, out, "sum", plan, {axis}, compute_encoder, d, s);
+}
+
 void qmm_op(
     const std::vector<array>& inputs,
     array& out,

@@ -211,7 +331,9 @@ void qmm_op(
       aligned = true;
     }
   } else {
-    if (B < 4) {
+    if (B < 4 && D >= 1024 && !gather) {
+      return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s);
+    } else if (B < 4) {
      name += "qvm";
      int bo = 64;
      int bd = 32;
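
Putting the dispatch conditions above in one place: the split-k path is reserved for skinny, long-context matvecs, and the split factor grows with the inner dimension. A Python restatement of the heuristic from qmm_op and qvm_split_k (the two function names below are illustrative, not MLX API):

    def takes_split_k_path(B, D, gather):
        # From qmm_op: small batch, long inner dimension, and not a gather op.
        return B < 4 and D >= 1024 and not gather

    def chosen_split_k(D):
        # From qvm_split_k: split more aggressively for very long D.
        return 32 if D > 8192 else 8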

python/tests/test_quantized.py

Lines changed: 25 additions & 0 deletions

@@ -163,6 +163,31 @@ def test_qvm(self):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
+    def test_qvm_splitk(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+        tests = product(
+            [128, 64, 32],  # group_size
+            [2, 4, 8],  # bits
+            [128],  # M
+            [16384],  # N
+            [1, 3],  # B
+        )
+        for group_size, bits, M, N, B in tests:
+            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
+                x_shape = (1, N) if B == 0 else (B, 1, N)
+                w_shape = (N, M) if B == 0 else (B, N, M)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
+                w_q, scales, biases = mx.quantize(w, group_size, bits)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, False, group_size, bits
+                )
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 2e-3)
+
     def test_throw(self):
         x = mx.random.normal(shape=(10, 512))
         w = mx.random.normal(shape=(32, 512))
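
From Python, nothing changes at the call site; quantized_matmul routes to the split-k kernel automatically when the dispatch conditions hold. A small usage sketch (shapes assumed, mirroring the new test):

    import mlx.core as mx

    # A single activation vector against a long inner dimension (D = 16384):
    # with B < 4 and D >= 1024 this now dispatches to qvm_split_k on Metal.
    x = mx.random.normal(shape=(1, 16384))
    w = mx.random.normal(shape=(16384, 128))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=False, group_size=64, bits=4
    )
    print(y.shape)  # (1, 128)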
