 // Copyright © 2023 Apple Inc.
 
 #include <cassert>
+#include <iostream>
 
+#include "mlx/backend/metal/copy.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
 
 namespace {
 
+template <typename T, int bits, int group_size>
+void _qmm(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const T* scales,
+    const T* biases,
+    int M,
+    int N,
+    int K) {
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int packs_in_group = group_size / pack_factor;
+  const int Ng = N / group_size;
+  const int Nw = N / pack_factor;
+
+  for (int m = 0; m < M; m++) {
+    const uint32_t* w_local = w;
+    const T* scales_local = scales;
+    const T* biases_local = biases;
+
+    std::fill(result, result + N, 0);
+
+    for (int k = 0; k < K; k++) {
+      T* result_local = result;
+      T xi = *x++;
+
+      for (int n = 0; n < N; n += group_size) {
+        T scale = *scales_local++;
+        T bias = *biases_local++;
+        for (int ng = 0; ng < packs_in_group; ng++) {
+          uint32_t wi = *w_local++;
+
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            (*result_local++) +=
+                xi * (scale * static_cast<T>(wi & bitmask) + bias);
+            wi >>= bits;
+          }
+        }
+      }
+    }
+
+    result += N;
+  }
+}
+
 template <typename T, int bits, int group_size>
 void _qmm_t(
     T* result,
@@ -55,7 +104,7 @@ void _qmm_t(
 }
 
 template <typename T>
-void _qmm_t_dispatch_typed(
+void _qmm_dispatch_typed(
     T* result,
     const T* x,
     const uint32_t* w,
@@ -65,30 +114,55 @@ void _qmm_t_dispatch_typed(
     int N,
     int K,
     int group_size,
-    int bits) {
+    int bits,
+    bool transposed_w) {
   switch (bits) {
     case 2: {
       switch (group_size) {
         case 64:
-          return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
+          }
         case 128:
-          return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
+          }
       }
     }
     case 4: {
       switch (group_size) {
         case 64:
-          return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
+          }
         case 128:
-          return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
+          }
       }
     }
     case 8: {
       switch (group_size) {
         case 64:
-          return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
+          }
         case 128:
-          return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
+          }
       }
     }
   }
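Each (bits, group_size) pair above maps to its own template instantiation, and the kernels derive their packing constants from those two parameters: how many quantized values fit in one uint32_t word and how many words share one scale/bias pair. A minimal standalone sketch of those constants, assuming the same 32-bit packing as _qmm and _qmm_t above (the helper name and the printout are illustrative, not part of this change):

// Illustrative only: derives the packing constants used by the kernels above.
#include <cstdio>

template <int bits, int group_size>
void print_packing() {
  constexpr int pack_factor = 32 / bits;                    // values per uint32_t word
  constexpr int packs_in_group = group_size / pack_factor;  // words per scale/bias group
  std::printf(
      "bits=%d group_size=%d -> %d values per word, %d words per group\n",
      bits,
      group_size,
      pack_factor,
      packs_in_group);
}

int main() {
  print_packing<2, 64>(); // 16 values per word, 4 words per group
  print_packing<4, 64>(); // 8 values per word, 8 words per group
  print_packing<8, 128>(); // 4 values per word, 32 words per group
}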
@@ -100,21 +174,22 @@ void _qmm_t_dispatch_typed(
   throw std::invalid_argument(msg.str());
 }
 
-void _qmm_t_dispatch(
+void _qmm_dispatch(
     array out,
     const array& x,
     const array& w,
     const array& scales,
     const array& biases,
     int bits,
-    int group_size) {
+    int group_size,
+    bool transposed_w) {
   int K = x.shape(-1);
   int M = x.size() / K;
-  int N = w.shape(1);
+  int N = out.shape(-1);
 
   switch (x.dtype()) {
     case float32:
-      _qmm_t_dispatch_typed<float>(
+      _qmm_dispatch_typed<float>(
           out.data<float>(),
           x.data<float>(),
           w.data<uint32_t>(),
@@ -124,10 +199,11 @@ void _qmm_t_dispatch(
           N,
           K,
           bits,
-          group_size);
+          group_size,
+          transposed_w);
       break;
     case float16:
-      _qmm_t_dispatch_typed<float16_t>(
+      _qmm_dispatch_typed<float16_t>(
           out.data<float16_t>(),
           x.data<float16_t>(),
           w.data<uint32_t>(),
@@ -137,10 +213,11 @@ void _qmm_t_dispatch(
           N,
           K,
           bits,
-          group_size);
+          group_size,
+          transposed_w);
       break;
     case bfloat16:
-      _qmm_t_dispatch_typed<bfloat16_t>(
+      _qmm_dispatch_typed<bfloat16_t>(
           out.data<bfloat16_t>(),
           x.data<bfloat16_t>(),
           w.data<uint32_t>(),
@@ -150,7 +227,8 @@ void _qmm_t_dispatch(
           N,
           K,
           bits,
-          group_size);
+          group_size,
+          transposed_w);
       break;
     default:
       throw std::invalid_argument(
@@ -163,22 +241,28 @@ void _qmm_t_dispatch(
 void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 4);
 
-  auto& x = inputs[0];
-  auto& w = inputs[1];
-  auto& scales = inputs[2];
-  auto& biases = inputs[3];
+  auto& x_pre = inputs[0];
+  auto& w_pre = inputs[1];
+  auto& scales_pre = inputs[2];
+  auto& biases_pre = inputs[3];
 
-  if (w.strides()[0] != 1) {
-    throw std::runtime_error("The quantized weight should be transposed");
-  }
+  auto ensure_row_contiguous = [](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      return arr_copy;
+    }
+  };
 
-  if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
-      !biases.flags().row_contiguous) {
-    throw std::runtime_error("x, scales and biases should be row contiguous.");
-  }
+  auto x = ensure_row_contiguous(x_pre);
+  auto w = ensure_row_contiguous(w_pre);
+  auto scales = ensure_row_contiguous(scales_pre);
+  auto biases = ensure_row_contiguous(biases_pre);
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  _qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
+  _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
 }
 
 } // namespace mlx::core
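The new _qmm path and the existing _qmm_t path share the same per-word dequantization: each uint32_t holds 32 / bits packed values, and each value is recovered as scale * q + bias using the scale and bias of its group. A minimal standalone sketch of that inner step, mirroring the loop bodies above (the function name and the vector return type are illustrative, not part of this change):

// Illustrative only: unpack one packed word and dequantize each value,
// mirroring the "wi & bitmask" / "wi >>= bits" inner loop of _qmm and _qmm_t.
#include <cstdint>
#include <vector>

template <int bits>
std::vector<float> dequantize_word(uint32_t wi, float scale, float bias) {
  constexpr int bitmask = (1 << bits) - 1; // e.g. 0xF for 4-bit values
  constexpr int pack_factor = 32 / bits;   // values packed into one uint32_t
  std::vector<float> out(pack_factor);
  for (int p = 0; p < pack_factor; p++) {
    out[p] = scale * static_cast<float>(wi & bitmask) + bias;
    wi >>= bits;
  }
  return out;
}

This per-word step is what both kernels repeat; they differ only in how they walk the packed weight matrix, which is why _qmm_dispatch_typed selects between them based on transposed_w (the transpose_ member of QuantizedMatmul).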