Skip to content

Commit 5f9ba30

Browse files
authored
Fix qmm_t for unaligned cases (#923)
1 parent 46caf0b commit 5f9ba30

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

mlx/backend/metal/kernels/quantized.metal

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
520520
const int K_g = K / group_size;
521521
const int y_row = tid.y * BM;
522522
const int y_col = tid.x * BN;
523+
523524
x += y_row * K;
524525
w += y_col * K_w;
525526
scales += y_col * K_g;
@@ -572,7 +573,10 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
572573
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
573574
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
574575

575-
if (y_row + offset_row < N) {
576+
// y_col corresponds to the row of the weight matrix; added to
577+
// offset_row, it should be less than the total number of rows,
578+
// otherwise skip.
579+
if (y_col + offset_row < N) {
576580
uint32_t wi = *w_local;
577581
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
578582
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];

python/tests/test_quantized.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,16 @@ def test_non_multiples(self):
229229
self.assertEqual(y_q.shape, y_hat.shape)
230230
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
231231

232+
# Test with unaligned sizes larger than 128
233+
w = mx.random.normal(shape=(99, 256))
234+
w_q, scales, biases = mx.quantize(w)
235+
w_hat = mx.dequantize(w_q, scales, biases)
236+
x = mx.random.normal(shape=(129, 256))
237+
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
238+
y_hat = x @ w_hat.T
239+
self.assertEqual(y_q.shape, y_hat.shape)
240+
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
241+
232242

233243
if __name__ == "__main__":
234244
unittest.main()

0 commit comments

Comments
 (0)