Commit e7f5059

Support for quantized matmul with w and w^T (#349)

* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt input for quantized_matmul
1 parent d7ac050 commit e7f5059
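For orientation, a minimal usage sketch of the operation this commit extends: mx.quantized_matmul with the weight stored transposed (w^T, the existing path) and non-transposed (w, the new qmm_n path). The transpose keyword and the use of mx.quantize to produce the packed weight, scales, and biases are assumptions for illustration; the exact Python signature is not part of this diff.

import mlx.core as mx

# Sketch only: the `transpose` keyword and `mx.quantize` call are assumed here,
# not taken from the hunks below.
group_size, bits = 64, 4
x = mx.random.normal((8, 512))

# Weight stored as (N, K); transpose=True computes x @ w.T (the existing path).
w_t, s_t, b_t = mx.quantize(mx.random.normal((256, 512)), group_size, bits)
y_t = mx.quantized_matmul(x, w_t, s_t, b_t, transpose=True, group_size=group_size, bits=bits)

# Weight stored as (K, N); transpose=False computes x @ w (the new path).
w_n, s_n, b_n = mx.quantize(mx.random.normal((512, 256)), group_size, bits)
y_n = mx.quantized_matmul(x, w_n, s_n, b_n, transpose=False, group_size=group_size, bits=bits)

With both layouts available, the gradient with respect to the input can itself be written as a quantized_matmul against the same weight with the opposite transpose flag; presumably that is what the "gradient wrt input" item refers to, though the file implementing it is not shown in this excerpt.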

File tree

12 files changed (+728 additions, -203 deletions)


benchmarks/python/comparative/bench_mlx.py

Lines changed: 15 additions & 6 deletions
@@ -4,6 +4,7 @@
 import math
 import os
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -59,15 +60,23 @@ def matmul(x, y):
     mx.eval(ys)


-def quant_matmul(x, w, s, b):
-    groups = x.shape[-1] // s.shape[-1]
-    width = 32 // (x.shape[-1] // w.shape[0])
+def _quant_matmul(x, w, s, b, group_size, bits):
     ys = []
     for i in range(10):
-        ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width))
+        ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
     mx.eval(ys)


+quant_matmul = {
+    "quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
+    "quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
+    "quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
+    "quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
+    "quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
+    "quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
+}
+
+
 def conv1d(x, y):
     ys = []
     for i in range(10):
@@ -356,8 +365,8 @@ def selu(x):
     elif args.benchmark == "matmul":
         print(bench(matmul, *xs))

-    elif args.benchmark == "quant_matmul":
-        print(bench(quant_matmul, *xs))
+    elif args.benchmark.startswith("quant_matmul"):
+        print(bench(quant_matmul[args.benchmark], *xs))

     elif args.benchmark == "linear":
         print(bench(linear, *xs))
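The hunk above replaces the single quant_matmul benchmark with one entry per (group_size, bits) pair, keyed by name and dispatched via a prefix match. A self-contained sketch of that registration pattern (the stub body and CLI handling are simplified stand-ins, not the script's real bench helper):

import argparse
from functools import partial

def _quant_matmul_stub(x, group_size, bits):
    # Stand-in for the real benchmark body.
    return (x, group_size, bits)

# One callable per configuration, keyed by benchmark name.
quant_matmul = {
    f"quant_matmul_{g}_{b}": partial(_quant_matmul_stub, group_size=g, bits=b)
    for g in (64, 128)
    for b in (2, 4, 8)
}

parser = argparse.ArgumentParser()
parser.add_argument("benchmark")
args = parser.parse_args(["quant_matmul_128_4"])

if args.benchmark.startswith("quant_matmul"):
    # The prefix check selects the family; the exact key selects the configuration.
    print(quant_matmul[args.benchmark]("x"))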

mlx/backend/accelerate/quantized.cpp

Lines changed: 6 additions & 10 deletions
@@ -76,20 +76,16 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& scales = inputs[2];
   auto& biases = inputs[3];

-  if (w.strides()[0] != 1) {
-    throw std::runtime_error("The quantized weight should be transposed");
-  }
-
-  if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
-      !biases.flags().row_contiguous) {
-    throw std::runtime_error("x, scales and biases should be row contiguous.");
-  }
+  bool condition =
+      (transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
+       scales.flags().row_contiguous && biases.flags().row_contiguous &&
+       x.dtype() == float32 && bits_ == 4 && group_size_ == 64);

-  if (x.dtype() == float32 && bits_ == 4 && group_size_ == 64) {
+  if (condition) {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
     int K = x.shape(-1);
     int M = x.size() / K;
-    int N = w.shape(1);
+    int N = out.shape(-1);
     _qmm_t_4_64(
         out.data<float>(),
         x.data<float>(),
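A note on the N = out.shape(-1) change above: the packed weight stores 32/bits quantized values per uint32, so neither of its dimensions is N in general, and its shape also depends on whether it holds w or w^T; the output's last axis is N in every case. A small shape check under that packing assumption (mx.quantize and its return order are assumed here for illustration):

import mlx.core as mx

bits, group_size = 4, 64
pack_factor = 32 // bits             # 8 quantized values per uint32

w = mx.random.normal((256, 512))     # logical (N=256, K=512) weight, transposed layout
w_q, scales, biases = mx.quantize(w, group_size, bits)

print(w_q.shape)     # (256, 64): last dim is K // pack_factor, not N
print(scales.shape)  # (256, 8):  one scale per group of 64 values along K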

mlx/backend/common/quantized.cpp

Lines changed: 113 additions & 29 deletions
@@ -1,13 +1,62 @@
11
// Copyright © 2023 Apple Inc.
22

33
#include <cassert>
4+
#include <iostream>
45

6+
#include "mlx/backend/metal/copy.h"
57
#include "mlx/primitives.h"
68

79
namespace mlx::core {
810

911
namespace {
1012

13+
template <typename T, int bits, int group_size>
14+
void _qmm(
15+
T* result,
16+
const T* x,
17+
const uint32_t* w,
18+
const T* scales,
19+
const T* biases,
20+
int M,
21+
int N,
22+
int K) {
23+
constexpr int bitmask = (1 << bits) - 1;
24+
constexpr int pack_factor = 32 / bits;
25+
constexpr int packs_in_group = group_size / pack_factor;
26+
const int Ng = N / group_size;
27+
const int Nw = N / pack_factor;
28+
29+
for (int m = 0; m < M; m++) {
30+
const uint32_t* w_local = w;
31+
const T* scales_local = scales;
32+
const T* biases_local = biases;
33+
34+
std::fill(result, result + N, 0);
35+
36+
for (int k = 0; k < K; k++) {
37+
T* result_local = result;
38+
T xi = *x++;
39+
40+
for (int n = 0; n < N; n += group_size) {
41+
T scale = *scales_local++;
42+
T bias = *biases_local++;
43+
for (int ng = 0; ng < packs_in_group; ng++) {
44+
uint32_t wi = *w_local++;
45+
46+
#pragma clang loop unroll(full)
47+
for (int p = 0; p < pack_factor; p++) {
48+
(*result_local++) +=
49+
xi * (scale * static_cast<T>(wi & bitmask) + bias);
50+
wi >>= bits;
51+
}
52+
}
53+
}
54+
}
55+
56+
result += N;
57+
}
58+
}
59+
1160
template <typename T, int bits, int group_size>
1261
void _qmm_t(
1362
T* result,
@@ -55,7 +104,7 @@ void _qmm_t(
55104
}
56105

57106
template <typename T>
58-
void _qmm_t_dispatch_typed(
107+
void _qmm_dispatch_typed(
59108
T* result,
60109
const T* x,
61110
const uint32_t* w,
@@ -65,30 +114,55 @@ void _qmm_t_dispatch_typed(
65114
int N,
66115
int K,
67116
int group_size,
68-
int bits) {
117+
int bits,
118+
bool transposed_w) {
69119
switch (bits) {
70120
case 2: {
71121
switch (group_size) {
72122
case 64:
73-
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
123+
if (transposed_w) {
124+
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
125+
} else {
126+
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
127+
}
74128
case 128:
75-
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
129+
if (transposed_w) {
130+
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
131+
} else {
132+
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
133+
}
76134
}
77135
}
78136
case 4: {
79137
switch (group_size) {
80138
case 64:
81-
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
139+
if (transposed_w) {
140+
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
141+
} else {
142+
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
143+
}
82144
case 128:
83-
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
145+
if (transposed_w) {
146+
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
147+
} else {
148+
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
149+
}
84150
}
85151
}
86152
case 8: {
87153
switch (group_size) {
88154
case 64:
89-
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
155+
if (transposed_w) {
156+
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
157+
} else {
158+
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
159+
}
90160
case 128:
91-
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
161+
if (transposed_w) {
162+
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
163+
} else {
164+
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
165+
}
92166
}
93167
}
94168
}
@@ -100,21 +174,22 @@ void _qmm_t_dispatch_typed(
100174
throw std::invalid_argument(msg.str());
101175
}
102176

103-
void _qmm_t_dispatch(
177+
void _qmm_dispatch(
104178
array out,
105179
const array& x,
106180
const array& w,
107181
const array& scales,
108182
const array& biases,
109183
int bits,
110-
int group_size) {
184+
int group_size,
185+
bool transposed_w) {
111186
int K = x.shape(-1);
112187
int M = x.size() / K;
113-
int N = w.shape(1);
188+
int N = out.shape(-1);
114189

115190
switch (x.dtype()) {
116191
case float32:
117-
_qmm_t_dispatch_typed<float>(
192+
_qmm_dispatch_typed<float>(
118193
out.data<float>(),
119194
x.data<float>(),
120195
w.data<uint32_t>(),
@@ -124,10 +199,11 @@ void _qmm_t_dispatch(
124199
N,
125200
K,
126201
bits,
127-
group_size);
202+
group_size,
203+
transposed_w);
128204
break;
129205
case float16:
130-
_qmm_t_dispatch_typed<float16_t>(
206+
_qmm_dispatch_typed<float16_t>(
131207
out.data<float16_t>(),
132208
x.data<float16_t>(),
133209
w.data<uint32_t>(),
@@ -137,10 +213,11 @@ void _qmm_t_dispatch(
137213
N,
138214
K,
139215
bits,
140-
group_size);
216+
group_size,
217+
transposed_w);
141218
break;
142219
case bfloat16:
143-
_qmm_t_dispatch_typed<bfloat16_t>(
220+
_qmm_dispatch_typed<bfloat16_t>(
144221
out.data<bfloat16_t>(),
145222
x.data<bfloat16_t>(),
146223
w.data<uint32_t>(),
@@ -150,7 +227,8 @@ void _qmm_t_dispatch(
150227
N,
151228
K,
152229
bits,
153-
group_size);
230+
group_size,
231+
transposed_w);
154232
break;
155233
default:
156234
throw std::invalid_argument(
@@ -163,22 +241,28 @@ void _qmm_t_dispatch(
163241
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
164242
assert(inputs.size() == 4);
165243

166-
auto& x = inputs[0];
167-
auto& w = inputs[1];
168-
auto& scales = inputs[2];
169-
auto& biases = inputs[3];
244+
auto& x_pre = inputs[0];
245+
auto& w_pre = inputs[1];
246+
auto& scales_pre = inputs[2];
247+
auto& biases_pre = inputs[3];
170248

171-
if (w.strides()[0] != 1) {
172-
throw std::runtime_error("The quantized weight should be transposed");
173-
}
249+
auto ensure_row_contiguous = [](const array& arr) {
250+
if (arr.flags().row_contiguous) {
251+
return arr;
252+
} else {
253+
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
254+
copy(arr, arr_copy, CopyType::General);
255+
return arr_copy;
256+
}
257+
};
174258

175-
if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
176-
!biases.flags().row_contiguous) {
177-
throw std::runtime_error("x, scales and biases should be row contiguous.");
178-
}
259+
auto x = ensure_row_contiguous(x_pre);
260+
auto w = ensure_row_contiguous(w_pre);
261+
auto scales = ensure_row_contiguous(scales_pre);
262+
auto biases = ensure_row_contiguous(biases_pre);
179263

180264
out.set_data(allocator::malloc_or_wait(out.nbytes()));
181-
_qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
265+
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
182266
}
183267

184268
} // namespace mlx::core
