|
6 | 6 | #include "mlx/backend/metal/copy.h" |
7 | 7 | #include "mlx/backend/metal/device.h" |
8 | 8 | #include "mlx/backend/metal/kernels.h" |
| 9 | +#include "mlx/backend/metal/reduce.h" |
9 | 10 | #include "mlx/backend/metal/utils.h" |
10 | 11 | #include "mlx/fast_primitives.h" |
11 | 12 | #include "mlx/primitives.h" |
@@ -148,6 +149,125 @@ void launch_qmm( |
148 | 149 | d.add_temporaries(std::move(copies), s.index); |
149 | 150 | } |
150 | 151 |
|
// Run a quantized vector-matrix multiply (qvm) with split-K parallelism:
// the contraction dimension D is partitioned into `split_k` chunks, each
// chunk computes a partial product into an intermediate array, and a
// strided "sum" reduction collapses the chunks into `out`.
//
// inputs:   {x, w, scales, biases} — the quantized-matmul operand set.
// D, O:     contraction (input) and output sizes of the weight matrix.
// B:        number of rows in x (small on this vector-matrix path).
// N:        batch size, i.e. number of independent matmuls.
// s:        stream on which to encode the work.
void qvm_split_k(
    const std::vector<array>& inputs,
    array& out,
    int group_size,
    int bits,
    int D,
    int O,
    int B,
    int N,
    const Stream& s) {
  // More splits for larger contractions. split_D is the per-chunk
  // contraction length (ceil division — the final chunk may be shorter;
  // see final_block_size below).
  int split_k = D > 8192 ? 32 : 8;
  int split_D = (D + split_k - 1) / split_k;
  // Each chunk is dispatched as an extra batch element along grid z.
  N *= split_k;

  int bo = 64;
  int bd = 32;
  MTL::Size group_dims = MTL::Size(bd, 2, 1);
  MTL::Size grid_dims = MTL::Size(O / bo, B, N);

  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];

  // Ensure that the last two dims are row contiguous.
  // TODO: Check if we really need this for x as well...
  std::vector<array> copies;
  auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
    auto stride_0 = arr.strides()[arr.ndim() - 2];
    auto stride_1 = arr.strides()[arr.ndim() - 1];
    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
      return arr;
    } else {
      // Materialize a contiguous copy; it is kept alive as a command-buffer
      // temporary via add_temporaries(std::move(copies), ...) below.
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      return arr_copy;
    }
  };
  auto x = ensure_row_contiguous_last_dims(x_pre);
  auto w = ensure_row_contiguous_last_dims(w_pre);
  auto scales = ensure_row_contiguous_last_dims(scales_pre);
  auto biases = ensure_row_contiguous_last_dims(biases_pre);

  int x_batch_ndims = x.ndim() - 2;
  auto x_shape = x.shape();
  auto x_strides = x.strides();
  int w_batch_ndims = w.ndim() - 2;
  auto w_shape = w.shape();
  auto w_strides = w.strides();
  auto s_strides = scales.strides();
  auto b_strides = biases.strides();

  // Add split_k dim with reshapes
  // View x as (..., split_k, B, split_D) without moving data: a batch dim of
  // extent split_k and stride split_D is inserted before the last two dims.
  x_shape.insert(x_shape.end() - 2, split_k);
  x_shape.back() /= split_k;
  x_strides.insert(x_strides.end() - 2, split_D);
  // NOTE(review): after the insert above, index x.ndim() - 1 addresses the
  // second-to-last dim (the row dim) of the NEW view, not the last one.
  // Overwriting its stride with split_D is only equivalent to keeping the
  // original row stride when B == 1 — confirm intent for B > 1.
  x_strides[x.ndim() - 1] = split_D;
  x_batch_ndims += 1;

  // Same trick for w/scales/biases: each chunk advances by split_D rows,
  // i.e. split_D * row-length elements in the (row-contiguous) buffers.
  w_shape.insert(w_shape.end() - 2, split_k);
  w_shape[w.ndim() - 1] /= split_k;
  w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
  w_batch_ndims += 1;
  s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
  b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));

  // Contraction length of the last chunk when split_k does not divide D.
  int final_block_size = D - (split_k - 1) * split_D;

  auto& d = metal::device(s.device);

  // Partial results: out's shape with a split_k dim inserted before the
  // last two dims; summed into `out` by the reduction at the end.
  auto temp_shape = out.shape();
  temp_shape.insert(temp_shape.end() - 2, split_k);
  array intermediate(temp_shape, x.dtype(), nullptr, {});
  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
  d.add_temporary(intermediate, s.index);

  // Kernel name is specialized on dtype, group size, bit width and split_k.
  std::ostringstream kname;
  auto type_string = get_type_string(x.dtype());
  kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_"
        << bits << "_spk_" << split_k;
  auto template_def = get_template_definition(
      kname.str(), "qvm_split_k", type_string, group_size, bits, split_k);

  // Encode and dispatch kernel
  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_output_array(intermediate, 4);
  // Per-chunk contraction length and output size (kernel's in/out dims).
  compute_encoder->setBytes(&split_D, sizeof(int), 5);
  compute_encoder->setBytes(&O, sizeof(int), 6);

  // Batch/view metadata so the kernel can locate each chunk's operands.
  compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
  set_vector_bytes(compute_encoder, x_shape, 8);
  set_vector_bytes(compute_encoder, x_strides, 9);
  compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
  set_vector_bytes(compute_encoder, w_shape, 11);
  set_vector_bytes(compute_encoder, w_strides, 12);
  set_vector_bytes(compute_encoder, s_strides, 13);
  set_vector_bytes(compute_encoder, b_strides, 14);
  compute_encoder->setBytes(&final_block_size, sizeof(int), 15);

  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
  d.add_temporaries(std::move(copies), s.index);

  // Sum the partials over the inserted split_k axis (third from last in
  // `intermediate`) directly into `out` via the shared reduce dispatch.
  int axis = intermediate.ndim() - 3;
  ReductionPlan plan(
      ReductionOpType::ContiguousStridedReduce,
      {intermediate.shape(axis)},
      {intermediate.strides(axis)});
  strided_reduce_general_dispatch(
      intermediate, out, "sum", plan, {axis}, compute_encoder, d, s);
}
| 270 | + |
151 | 271 | void qmm_op( |
152 | 272 | const std::vector<array>& inputs, |
153 | 273 | array& out, |
@@ -211,7 +331,9 @@ void qmm_op( |
211 | 331 | aligned = true; |
212 | 332 | } |
213 | 333 | } else { |
214 | | - if (B < 4) { |
| 334 | + if (B < 4 && D >= 1024 && !gather) { |
| 335 | + return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s); |
| 336 | + } else if (B < 4) { |
215 | 337 | name += "qvm"; |
216 | 338 | int bo = 64; |
217 | 339 | int bd = 32; |
|
0 commit comments