Skip to content

Commit c7b0300

Browse files
authored
Fix batched qmv bug (#1758)
1 parent da8c885 commit c7b0300

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

mlx/backend/metal/kernels/quantized.h

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,13 +1323,14 @@ template <typename T, int group_size, int bits, int D, bool batched>
13231323
uint quad_gid [[quadgroup_index_in_threadgroup]],
13241324
uint quad_lid [[thread_index_in_quadgroup]]) {
13251325
if (batched) {
1326+
int M = x_shape[x_batch_ndims];
13261327
adjust_matrix_offsets<T>(
13271328
x,
13281329
w,
13291330
scales,
13301331
biases,
13311332
y,
1332-
out_vec_size,
1333+
out_vec_size * M,
13331334
x_batch_ndims,
13341335
x_shape,
13351336
x_strides,
@@ -1374,13 +1375,14 @@ template <typename T, int group_size, int bits, bool batched>
13741375
uint simd_gid [[simdgroup_index_in_threadgroup]],
13751376
uint simd_lid [[thread_index_in_simdgroup]]) {
13761377
if (batched) {
1378+
int M = x_shape[x_batch_ndims];
13771379
adjust_matrix_offsets<T>(
13781380
x,
13791381
w,
13801382
scales,
13811383
biases,
13821384
y,
1383-
out_vec_size,
1385+
out_vec_size * M,
13841386
x_batch_ndims,
13851387
x_shape,
13861388
x_strides,
@@ -1425,13 +1427,14 @@ template <typename T, const int group_size, const int bits, bool batched>
14251427
uint simd_gid [[simdgroup_index_in_threadgroup]],
14261428
uint simd_lid [[thread_index_in_simdgroup]]) {
14271429
if (batched) {
1430+
int M = x_shape[x_batch_ndims];
14281431
adjust_matrix_offsets<T>(
14291432
x,
14301433
w,
14311434
scales,
14321435
biases,
14331436
y,
1434-
out_vec_size,
1437+
out_vec_size * M,
14351438
x_batch_ndims,
14361439
x_shape,
14371440
x_strides,
@@ -1476,13 +1479,14 @@ template <typename T, const int group_size, const int bits, bool batched>
14761479
uint simd_gid [[simdgroup_index_in_threadgroup]],
14771480
uint simd_lid [[thread_index_in_simdgroup]]) {
14781481
if (batched) {
1482+
int M = x_shape[x_batch_ndims];
14791483
adjust_matrix_offsets<T>(
14801484
x,
14811485
w,
14821486
scales,
14831487
biases,
14841488
y,
1485-
out_vec_size,
1489+
out_vec_size * M,
14861490
x_batch_ndims,
14871491
x_shape,
14881492
x_strides,
@@ -1527,13 +1531,14 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
15271531
uint3 tid [[threadgroup_position_in_grid]],
15281532
uint simd_gid [[simdgroup_index_in_threadgroup]],
15291533
uint simd_lid [[thread_index_in_simdgroup]]) {
1534+
int M = x_shape[x_batch_ndims];
15301535
adjust_matrix_offsets<T>(
15311536
x,
15321537
w,
15331538
scales,
15341539
biases,
15351540
y,
1536-
out_vec_size,
1541+
out_vec_size * M,
15371542
x_batch_ndims,
15381543
x_shape,
15391544
x_strides,
@@ -1706,6 +1711,7 @@ template <typename T, int group_size, int bits>
17061711
uint3 tid [[threadgroup_position_in_grid]],
17071712
uint simd_gid [[simdgroup_index_in_threadgroup]],
17081713
uint simd_lid [[thread_index_in_simdgroup]]) {
1714+
int M = x_shape[x_batch_ndims];
17091715
adjust_matrix_offsets<T>(
17101716
x,
17111717
w,
@@ -1714,7 +1720,7 @@ template <typename T, int group_size, int bits>
17141720
lhs_indices,
17151721
rhs_indices,
17161722
y,
1717-
out_vec_size,
1723+
out_vec_size * M,
17181724
batch_ndims,
17191725
batch_shape,
17201726
lhs_strides,
@@ -1767,6 +1773,7 @@ template <typename T, int group_size, int bits>
17671773
uint3 tid [[threadgroup_position_in_grid]],
17681774
uint simd_gid [[simdgroup_index_in_threadgroup]],
17691775
uint simd_lid [[thread_index_in_simdgroup]]) {
1776+
int M = x_shape[x_batch_ndims];
17701777
adjust_matrix_offsets<T>(
17711778
x,
17721779
w,
@@ -1775,7 +1782,7 @@ template <typename T, int group_size, int bits>
17751782
lhs_indices,
17761783
rhs_indices,
17771784
y,
1778-
out_vec_size,
1785+
out_vec_size * M,
17791786
batch_ndims,
17801787
batch_shape,
17811788
lhs_strides,
@@ -1828,6 +1835,7 @@ template <typename T, int group_size, int bits>
18281835
uint3 tid [[threadgroup_position_in_grid]],
18291836
uint simd_gid [[simdgroup_index_in_threadgroup]],
18301837
uint simd_lid [[thread_index_in_simdgroup]]) {
1838+
int M = x_shape[x_batch_ndims];
18311839
adjust_matrix_offsets<T>(
18321840
x,
18331841
w,
@@ -1836,7 +1844,7 @@ template <typename T, int group_size, int bits>
18361844
lhs_indices,
18371845
rhs_indices,
18381846
y,
1839-
out_vec_size,
1847+
out_vec_size * M,
18401848
batch_ndims,
18411849
batch_shape,
18421850
lhs_strides,

python/tests/test_quantized.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,12 @@ def test_small_matrix(self):
212212
w_hat = mx.dequantize(w_q, scales, biases)
213213

214214
# Test qmv
215-
x = mx.random.normal(shape=(3, 1, 256))
216-
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
217-
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
218-
self.assertEqual(y_q.shape, y_hat.shape)
219-
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
215+
for shape in [(3, 1, 256), (3, 4, 256)]:
216+
x = mx.random.normal(shape=shape)
217+
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
218+
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
219+
self.assertEqual(y_q.shape, y_hat.shape)
220+
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
220221

221222
# Test qmm_t
222223
x = mx.random.normal(shape=(3, 10, 256))

0 commit comments

Comments
 (0)