@@ -14,6 +14,19 @@ namespace mlx::core {
 
 namespace {
 
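+// Return arr as-is when it is already row-contiguous; otherwise make a
+// contiguous copy and register the copy with the encoder as a temporary.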
+array ensure_row_contiguous(
+    const array& arr,
+    cpu::CommandEncoder& encoder,
+    Stream s) {
+  if (arr.flags().row_contiguous) {
+    return arr;
+  } else {
+    auto arr_cpy = contiguous_copy_cpu(arr, s);
+    encoder.add_temporary(arr_cpy);
+    return arr_cpy;
+  }
+}
+
 const static float FP4_LUT[16] = {
     +0.0f,
     +0.5f,
@@ -922,20 +935,9 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& scales_pre = inputs[2];
 
   auto& encoder = cpu::get_command_encoder(stream());
-  auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) {
-    if (arr.flags().row_contiguous) {
-      return arr;
-    } else {
-      auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
-      copy_cpu(arr, arr_cpy, CopyType::General, s);
-      encoder.add_temporary(arr_cpy);
-      return arr_cpy;
-    }
-  };
-
-  auto x = ensure_row_contiguous(x_pre);
-  auto w = ensure_row_contiguous(w_pre);
-  auto scales = ensure_row_contiguous(scales_pre);
+  auto x = ensure_row_contiguous(x_pre, encoder, stream());
+  auto w = ensure_row_contiguous(w_pre, encoder, stream());
+  auto scales = ensure_row_contiguous(scales_pre, encoder, stream());
 
   out.set_data(allocator::malloc(out.nbytes()));
 
@@ -944,7 +946,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(scales);
   encoder.set_output_array(out);
   if (mode_ == QuantizationMode::Affine) {
-    auto biases = ensure_row_contiguous(inputs[3]);
+    auto biases = ensure_row_contiguous(inputs[3], encoder, stream());
     encoder.set_input_array(biases);
     encoder.dispatch([out = array::unsafe_weak_copy(out),
                       x = array::unsafe_weak_copy(x),
@@ -1052,6 +1054,105 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
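+// E8M0 is an exponent-only FP8 encoding: a code c represents 2^(c - 127),
+// with 0xFF reserved for NaN. Round log2(x) to the nearest integer exponent
+// and clamp it to [-127, 127]; zero and negative inputs map to the smallest
+// exponent.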
+uint8_t to_fp8_e8m0(float x) {
+  if (!std::isfinite(x)) {
+    return 0xFF;
+  }
+  if (x <= 0.0f) { // also catch 0, for which log2 would return -inf
+    return 0x00;
+  }
+  float le = std::log2(x);
+  int n = int(std::round(le));
+
+  n = n < -127 ? -127 : n;
+  n = n > 127 ? 127 : n;
+  return static_cast<uint8_t>(n + 127);
+}
+
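+// Round to the nearest FP4 (E2M1) code. The representable magnitudes are
+// {0, 0.5, 1, 1.5, 2, 3, 4, 6}; the cut points below sit halfway between
+// neighbors, with ties broken toward the code with an even mantissa bit.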
+uint8_t to_fp4_e2m1(float x) {
+  if (std::isnan(x)) {
+    return 0x7;
+  }
+
+  const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
+  x = std::abs(x);
+
+  uint8_t bits;
+  if (x > 5.0f) {
+    bits = 0x7;
+  } else if (x >= 3.5f) {
+    bits = 0x6;
+  } else if (x > 2.5f) {
+    bits = 0x5;
+  } else if (x >= 1.75f) {
+    bits = 0x4;
+  } else if (x > 1.25f) {
+    bits = 0x3;
+  } else if (x >= 0.75f) {
+    bits = 0x2;
+  } else if (x > 0.25f) {
+    bits = 0x1;
+  } else {
+    bits = 0x0;
+  }
+  return bits | sign_bit;
+}
+
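+// Quantize and immediately dequantize w_arr group by group: each group of
+// group_size elements shares one FP8 scale (stored via detail::ToFP8 for
+// groups of 16, as E8M0 for groups of 32), chosen so the group's absmax maps
+// to the largest representable element value (6 for FP4, 448 for FP8).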
+template <typename T>
+void fp_quantize_dequantize(
+    const array& w_arr,
+    array& out_arr,
+    int bits,
+    int group_size,
+    size_t w_size) {
+  auto w = w_arr.data<T>();
+  auto out = out_arr.data<T>();
+
+  size_t n_groups = w_size / group_size;
+
+  for (size_t i = 0; i < n_groups; ++i) {
+    size_t idx = i * group_size;
+    float scale = -std::numeric_limits<float>::infinity();
+    for (int j = 0; j < group_size; ++j) {
+      scale = std::max(scale, std::abs(w[idx + j]));
+    }
+    scale /= bits == 4 ? 6.0f : 448.0f;
+    if (group_size == 16) {
+      scale = dequantize_scale<float, 16>(detail::ToFP8()(scale));
+    } else {
+      scale = dequantize_scale<float, 32>(to_fp8_e8m0(scale));
+    }
+
+    for (int j = 0; j < group_size; ++j) {
+      float w_el = scale == 0 ? 0.0f : w[idx + j] / scale;
+      float output;
+      if (bits == 8) {
+        output = detail::FromFP8()(detail::ToFP8()(w_el));
+      } else {
+        output = FP4_LUT[to_fp4_e2m1(w_el)];
+      }
+      out[idx + j] = static_cast<T>(scale * output);
+    }
+  }
+}
+
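+// Select the fp_quantize_dequantize instantiation matching the element type
+// of w; non-float inputs are rejected.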
+void dispatch_quantize_dequantize(
+    const array& w,
+    array& out,
+    int bits,
+    int group_size) {
+  if (w.dtype() == float16) {
+    fp_quantize_dequantize<float16_t>(w, out, bits, group_size, w.size());
+  } else if (w.dtype() == bfloat16) {
+    fp_quantize_dequantize<bfloat16_t>(w, out, bits, group_size, w.size());
+  } else if (w.dtype() == float32) {
+    fp_quantize_dequantize<float>(w, out, bits, group_size, w.size());
+  } else {
+    throw std::runtime_error(
+        "[quantize_dequantize] Only supports floating point inputs");
+  }
+}
+
 template <typename T, typename U>
 void quantize(
     const T* w,
@@ -1136,26 +1237,15 @@ void dispatch_quantize(
 void fast::Quantize::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  auto ensure_row_contiguous = [s = stream()](const array& arr) {
-    if (arr.flags().row_contiguous) {
-      return std::make_pair(arr, false);
-    } else {
-      return std::make_pair(contiguous_copy_cpu(arr, s), true);
-    }
-  };
-
-  auto [w, copied] = ensure_row_contiguous(inputs[0]);
+  auto& encoder = cpu::get_command_encoder(stream());
+  auto w = ensure_row_contiguous(inputs[0], encoder, stream());
   auto& out = outputs[0];
   out.set_data(allocator::malloc(out.nbytes()));
 
   auto& scales = outputs[1];
   auto& biases = outputs[2];
   scales.set_data(allocator::malloc(scales.nbytes()));
   biases.set_data(allocator::malloc(biases.nbytes()));
-  auto& encoder = cpu::get_command_encoder(stream());
-  if (copied) {
-    encoder.add_temporary(w);
-  }
   encoder.set_input_array(w);
   encoder.set_input_array(scales);
   encoder.set_input_array(biases);
@@ -1238,6 +1328,43 @@ void fast::ConvertFP8::eval_cpu(
 }
 
 void QQMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  throw std::runtime_error("QQMatmul not implemented on CPU.");
+  auto& encoder = cpu::get_command_encoder(stream());
+
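+  // Fast path: w arrives packed as uint32 and x is a single row. Quantize
+  // and dequantize x into xhat so the product matches quantized-x-times-
+  // quantized-w semantics, then run the regular fp quantized matmul on w.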
+  bool w_quantized = (inputs[1].dtype() == uint32);
+  if (w_quantized && inputs[0].shape(-2) == 1) {
+    bool donate_x = inputs[0].is_donatable();
+    auto x = ensure_row_contiguous(inputs[0], encoder, stream());
+    auto w = ensure_row_contiguous(inputs[1], encoder, stream());
+    auto scales = ensure_row_contiguous(inputs[2], encoder, stream());
+
+    out.set_data(allocator::malloc(out.nbytes()));
+
+    // If x is a copy it should be donatable
+    donate_x |= x.is_donatable();
+    auto xhat = donate_x
+        ? x
+        : array(allocator::malloc(x.nbytes()), x.shape(), x.dtype());
+    if (!donate_x) {
+      encoder.add_temporary(xhat);
+    }
+    encoder.set_input_array(x);
+    encoder.set_input_array(w);
+    encoder.set_input_array(scales);
+    encoder.set_output_array(out);
+    encoder.dispatch([out = array::unsafe_weak_copy(out),
+                      x = array::unsafe_weak_copy(x),
+                      xhat = array::unsafe_weak_copy(xhat),
+                      w = array::unsafe_weak_copy(w),
+                      scales = array::unsafe_weak_copy(scales),
+                      group_size_ = group_size_,
+                      bits_ = bits_]() mutable {
+      dispatch_quantize_dequantize(x, xhat, bits_, group_size_);
+      fp_qmm_dispatch(out, xhat, w, scales, group_size_, bits_, true);
+    });
+    return;
+  } else {
+    throw std::runtime_error("[QQMatmul] NYI for the general case");
+  }
 }
+
 } // namespace mlx::core