@@ -1,3 +1,6 @@
+#include <arm_fp16.h>
+
+#include <cstdint>
 #include <iostream>
 #include <sstream>
 
@@ -37,6 +40,12 @@ mx::array quantized_matmul(const mx::array &scales, // Input array scale
   if (b.shape().size() != 2) {
     throw std::runtime_error("quantized_matmul: b must be a 2D array");
   }
+  if (bits != 4) {
+    throw std::runtime_error("quantized_matmul: bits must be 4");
+  }
+  if (group_size != 64) {
+    throw std::runtime_error("quantized_matmul: group_size must be 64");
+  }
   auto out_shape = a.shape();
   if (out_shape.size() != 2) {
     throw std::runtime_error("quantized_matmul: a must be a 2D array");
@@ -64,17 +73,61 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
   encoder.set_input_array(b);
   encoder.set_output_array(out);
 
-  // Launch the CPU kernel
-  encoder.dispatch([a_ptr = a.data<uint32_t>(), a_shape = a.shape(), a_strides = a.strides(),
-                    b_ptr = b.data<float16_t>(), b_shape = b.shape(), b_strides = b.strides(),
-                    out_ptr = out.data<float16_t>(), scales_ptr = scales.data<float16_t>(),
-                    scales_shape = scales.shape(), scales_strides = scales.strides(),
-                    biases_ptr = biases.data<float16_t>(), biases_shape = biases.shape(),
-                    biases_strides = biases.strides(), group_size, bits]() {
-    int M = a_shape[0];
-    int N = a_shape[1];
-    int K = b_shape[0]; // because we transposed b
+  if (scales.shape() != biases.shape()) {
+    throw std::runtime_error("quantized_matmul: scales and biases must have the same shape");
+  }
+  if (b.shape()[0] != scales.shape()[0]) {
+    throw std::runtime_error("quantized_matmul: b must have the same number of rows as scales");
+  }
+  if (b.shape()[1] != scales.shape()[1] * group_size / 8) {
+    throw std::runtime_error("quantized_matmul: b must have scales.shape()[1] * group_size / 8 columns");
+  }
 
+  // Launch the CPU kernel
+  encoder.dispatch([out_ptr = out.data<float16_t>(), out_shape = out.shape(), out_strides = out.strides(),
+                    a = mx::array::unsafe_weak_copy(a), b = mx::array::unsafe_weak_copy(b),
+                    scales = mx::array::unsafe_weak_copy(scales), biases = mx::array::unsafe_weak_copy(biases)]() {
+    int M = a.shape()[0];
+    int N = a.shape()[1];
+    int K = b.shape()[0]; // b is transposed: row k of b holds the packed weights for output column k
+    const int group_size = 64;
+    const int bits = 4;
+    const int group_per_row = N / group_size;
+    const float16_t *a_ptr = a.data<float16_t>();
+    const uint32_t *b_ptr = b.data<uint32_t>();
+    const float16_t *scales_ptr = scales.data<float16_t>();
+    const float16_t *biases_ptr = biases.data<float16_t>();
+    const uint32_t item_mask = (1 << bits) - 1;
+    for (int i = 0; i < M; i++) {
+      for (int k = 0; k < K; k++) {
+        // out[i, k] is the dot product of row i of a with the dequantized row k of b,
+        // so the accumulator must span all groups of that row.
+        float16_t sum = 0;
+        for (int group_idx = 0; group_idx < group_per_row; group_idx++) {
+          // Each group of 64 weights shares one scale and one bias.
+          int64_t scales_loc =
+              mx::elem_to_loc(k * N / group_size + group_idx, scales.shape(), scales.strides());
+          int64_t biases_loc =
+              mx::elem_to_loc(k * N / group_size + group_idx, biases.shape(), biases.strides());
+          float16_t scale = scales_ptr[scales_loc];
+          float16_t bias = biases_ptr[biases_loc];
+          const int packs_per_item = 32 / bits; // 8 4-bit weights per uint32
+          for (int item_idx = 0; item_idx < group_size; item_idx += packs_per_item) {
+            int64_t b_loc =
+                mx::elem_to_loc((k * N + group_idx * group_size + item_idx) / 8, b.shape(), b.strides());
+            uint32_t b_val = b_ptr[b_loc];
+            const uint8_t *b_bytes = reinterpret_cast<const uint8_t *>(&b_val);
+            for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
+              int64_t a_loc = mx::elem_to_loc(i * N + group_idx * group_size + item_idx + pack_idx,
+                                              a.shape(), a.strides());
+              // Unpack one 4-bit weight (two per byte) and dequantize it.
+              uint8_t item_val = (b_bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & item_mask;
+              float16_t w = static_cast<float16_t>(item_val) * scale + bias;
+              sum += a_ptr[a_loc] * w;
+            }
+          }
+        }
+        int64_t out_loc = mx::elem_to_loc(i * K + k, out_shape, out_strides);
+        out_ptr[out_loc] = sum;
+      }
+    }
   });
 }
 
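For reference, the snippet below is a minimal standalone sketch (not part of this commit) of the 4-bit unpacking and affine dequantization arithmetic the new kernel performs on each packed `uint32_t`. It uses plain `float` in place of `float16_t` and hard-codes an example scale and bias, so it only illustrates the bit manipulation, not the MLX API.

```cpp
// Standalone illustration of 4-bit (group-quantized) weight unpacking.
#include <cstdint>
#include <cstdio>

int main() {
  const int bits = 4;
  const uint32_t item_mask = (1u << bits) - 1;  // 0xF
  const int packs_per_item = 32 / bits;         // 8 4-bit values per uint32

  // Pack the values 0..7 into one uint32, least-significant nibble first,
  // matching the little-endian byte/nibble order the kernel assumes.
  uint32_t packed = 0;
  for (int j = 0; j < packs_per_item; j++) {
    packed |= static_cast<uint32_t>(j & item_mask) << (j * bits);
  }

  // Dequantize each 4-bit value with an arbitrary per-group scale and bias.
  const float scale = 0.5f;
  const float bias = -1.0f;
  const uint8_t *bytes = reinterpret_cast<const uint8_t *>(&packed);
  for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
    uint8_t q = (bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & item_mask;
    float w = static_cast<float>(q) * scale + bias;  // w = q * scale + bias
    std::printf("q=%u  w=%.2f\n", static_cast<unsigned>(q), w);
  }
  return 0;
}
```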