@@ -54,6 +54,17 @@ mx::array quantized_matmul(const mx::array &scales, // Input array scale
     if (!transpose_b) {
         throw std::runtime_error("quantized_matmul: b must be transposed");
     }
+
+    if (scales.shape() != biases.shape()) {
+        throw std::runtime_error("quantized_matmul: scales and biases must have the same shape");
+    }
+    if (b.shape()[0] != scales.shape()[0]) {
+        throw std::runtime_error("quantized_matmul: b must have the same number of rows as scales");
+    }
+    if (b.shape()[1] != scales.shape()[1] * group_size / 8) {
+        throw std::runtime_error("quantized_matmul: a must have the same number of columns as scales");
+    }
+
     return mx::array(
         /* const mx::Shape& shape = */ out_shape,
         /* mx::Dtype dtype = */ mx::float16,
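The added shape checks follow from the 4-bit packing layout: each uint32 element of b holds 32 / bits = 8 quantized values, and each scales/biases entry covers group_size consecutive values of an unpacked row. As a worked example with illustrative sizes, assuming the w4a16_g64 configuration this diff targets (bits = 4, group_size = 64): an unpacked weight row of N = 4096 columns gives scales.shape()[1] = 4096 / 64 = 64 and b.shape()[1] = 4096 / 8 = 512, which is exactly scales.shape()[1] * group_size / 8 = 64 * 64 / 8 = 512.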
@@ -73,14 +84,11 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
     encoder.set_input_array(b);
     encoder.set_output_array(out);
 
-    if (scales.shape() != biases.shape()) {
-        throw std::runtime_error("quantized_matmul: scales and biases must have the same shape");
+    if (!a.flags().row_contiguous) {
+        throw std::runtime_error("quantized_matmul: a must be contiguous");
     }
-    if (b.shape()[0] != scales.shape()[0]) {
-        throw std::runtime_error("quantized_matmul: b must have the same number of rows as scales");
-    }
-    if (b.shape()[1] != scales.shape()[1] * group_size / 8) {
-        throw std::runtime_error("quantized_matmul: a must have the same number of columns as scales");
+    if (!b.flags().row_contiguous) {
+        throw std::runtime_error("quantized_matmul: b must be contiguous");
     }
 
     // Launch the CPU kernel
@@ -100,32 +108,32 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
     uint32_t item_mask = (1 << bits) - 1;
     for (int i = 0; i < M; i++) {
         for (int k = 0; k < K; k++) {
+            float sum = 0;
             for (int group_idx = 0; group_idx < group_per_row; group_idx++) {
                 int64_t scales_loc =
-                    mx::elem_to_loc(k * N / group_size + group_idx, scales.shape(), scales.strides());
+                    mx::elem_to_loc(k * group_per_row + group_idx, scales.shape(), scales.strides());
                 int64_t biases_loc =
-                    mx::elem_to_loc(k * N / group_size + group_idx, biases.shape(), biases.strides());
-                float16_t sum = 0;
+                    mx::elem_to_loc(k * group_per_row + group_idx, biases.shape(), biases.strides());
                 float16_t scale = scales_ptr[scales_loc];
                 float16_t bias = biases_ptr[biases_loc];
+                int64_t b_loc = mx::elem_to_loc((k * N + group_idx * group_size) / 8, b.shape(), b.strides());
+                int64_t a_loc = mx::elem_to_loc(i * N + group_idx * group_size, a.shape(), a.strides());
                 const int packs_per_item = 32 / bits;
                 for (int item_idx = 0; item_idx < group_size; item_idx += packs_per_item) {
-                    int64_t b_loc =
-                        mx::elem_to_loc((k * N + group_idx * group_size + item_idx) / 8, b.shape(), b.strides());
                     uint32_t b_val = b_ptr[b_loc];
                     uint8_t *b_bytes = reinterpret_cast<uint8_t *>(&b_val);
                     for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
-                        int64_t a_loc = mx::elem_to_loc(i * N + group_idx * group_size + item_idx + pack_idx,
-                                                        a.shape(), a.strides());
                         uint8_t item_val = (b_bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & item_mask;
-                        float16_t b = static_cast<float16_t>(item_val) * scale + bias;
-                        float16_t a = a_ptr[a_loc];
+                        float b = static_cast<float>(item_val) * scale + bias;
+                        float a = a_ptr[a_loc];
                         sum += a * b;
+                        a_loc += 1;
                     }
+                    b_loc += 1;
                 }
-                int64_t out_loc = mx::elem_to_loc(i * K + k, out_shape, out_strides);
-                out_ptr[out_loc] = sum;
             }
+            int64_t out_loc = mx::elem_to_loc(i * K + k, out_shape, out_strides);
+            out_ptr[out_loc] = static_cast<float16_t>(sum);
         }
     }
 });
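To make the bit manipulation above easier to follow, here is a small standalone sketch (not part of the diff) of the per-word dequantization the inner loops perform, assuming the same layout: 4-bit values packed eight per uint32, two per byte, reconstructed as item_val * scale + bias with the group's scale and bias.

// Standalone sketch (not from the diff): unpack and dequantize one uint32 word
// that stores eight 4-bit weights, using example scale/bias values for its group.
#include <cstdint>
#include <cstdio>

int main() {
    const int bits = 4;
    const int packs_per_item = 32 / bits;         // 8 quantized values per uint32 word
    const uint32_t item_mask = (1u << bits) - 1;  // 0xF
    uint32_t b_val = 0x76543210u;                 // example word: values 0,1,...,7 in the little-endian byte view
    float scale = 0.5f, bias = -1.0f;             // example group parameters

    const uint8_t *b_bytes = reinterpret_cast<const uint8_t *>(&b_val);
    for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
        // Two 4-bit values per byte: low nibble first, then high nibble.
        uint8_t item_val = (b_bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & item_mask;
        float w = static_cast<float>(item_val) * scale + bias;  // dequantized weight
        std::printf("pack %d: q=%d w=%.1f\n", pack_idx, static_cast<int>(item_val), w);
    }
    return 0;
}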
@@ -142,8 +150,65 @@ void QuantizedMatmul::eval_cpu(const std::vector<mx::array> &inputs, std::vector
     quantized_matmul_impl(scales, biases, a, b, out, group_size_, bits_, stream());
 }
 
-void QuantizedMatmul::eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &out) {
-    throw std::runtime_error("QuantizedMatmul has no GPU implementation.");
+void load_library(mx::Device d, const char *path) {
+    auto &md = mx::metal::device(d);
+    md.register_library("tiny_llm_ext_ref", path);
+}
+
+void QuantizedMatmul::eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
+    auto &scales = inputs[0];
+    auto &biases = inputs[1];
+    auto &a = inputs[2];
+    auto &b = inputs[3];
+    auto &out = outputs[0];
+
+    auto &s = stream();
+    auto &d = mx::metal::device(s.device);
+    out.set_data(mx::allocator::malloc(out.nbytes()));
+
+    // Make a kernel from this metal library
+    auto kernel = d.get_kernel("quantized_matmul_w4a16_g64", "tiny_llm_ext_ref");
+
+    // Prepare to encode kernel
+    auto &compute_encoder = d.get_command_encoder(s.index);
+    compute_encoder.set_compute_pipeline_state(kernel);
+
+    // Kernel parameters are registered with buffer indices corresponding to
+    // those in the kernel declaration at axpby.metal
+    int ndim = out.ndim();
+
+    // Encode input arrays to kernel
+    compute_encoder.set_input_array(scales, 0);
+    compute_encoder.set_input_array(biases, 1);
+    compute_encoder.set_input_array(a, 2);
+    compute_encoder.set_input_array(b, 3);
+    // Encode output arrays to kernel
+    compute_encoder.set_output_array(out, 4);
+
+    if (!a.flags().row_contiguous) {
+        throw std::runtime_error("quantized_matmul: a must be contiguous");
+    }
+    if (!b.flags().row_contiguous) {
+        throw std::runtime_error("quantized_matmul: b must be contiguous");
+    }
+
+    int M = a.shape()[0];
+    int N = a.shape()[1];
+    int K = b.shape()[0];
+
+    // Encode matrix parameters
+    compute_encoder.set_bytes(M, 5);
+    compute_encoder.set_bytes(N, 6);
+    compute_encoder.set_bytes(K, 7);
+
+    size_t tgp_size = kernel->maxTotalThreadsPerThreadgroup();
+    MTL::Size num_threadgroups = MTL::Size((M * K + tgp_size - 1) / tgp_size, 1, 1);
+    MTL::Size num_threads_per_group = MTL::Size(tgp_size, 1, 1);
+
+    // Launch the grid with the given number of threads divided among
+    // the given threadgroups
+    compute_encoder.dispatch_threadgroups(num_threadgroups, num_threads_per_group);
 }
 
 bool QuantizedMatmul::is_equivalent(const Primitive &other) const {
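A note on the GPU dispatch above: it launches one thread per output element, rounding M * K up to whole threadgroups. With example sizes M = 32, K = 4096 and tgp_size = 1024, that is (131072 + 1023) / 1024 = 128 threadgroups of 1024 threads; when M * K is not a multiple of tgp_size, the extra threads in the last group must be masked out inside the kernel. The Metal source of quantized_matmul_w4a16_g64 is not part of this diff, so the sketch below is illustrative only: it shows the index mapping a kernel dispatched this way would be expected to use, with hypothetical names.

// Illustrative only (not from the diff): recover an output coordinate from a
// flat thread id under the one-thread-per-output dispatch, masking the
// rounded-up tail. The real kernel lives in the Metal source, not shown here.
#include <cstddef>

inline bool thread_to_output(size_t tid, int M, int K, int &row, int &col) {
    if (tid >= static_cast<size_t>(M) * K) return false;  // tail thread: do no work
    row = static_cast<int>(tid / K);  // row of a
    col = static_cast<int>(tid % K);  // row of the transposed b
    return true;
}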