
Commit 0cb7965

Fp qmv (#2984)
1 parent c9031c3 commit 0cb7965

22 files changed: +1050 -162 lines

mlx/backend/cpu/quantized.cpp

Lines changed: 156 additions & 29 deletions
@@ -14,6 +14,19 @@ namespace mlx::core {
 
 namespace {
 
+array ensure_row_contiguous(
+    const array& arr,
+    cpu::CommandEncoder& encoder,
+    Stream s) {
+  if (arr.flags().row_contiguous) {
+    return arr;
+  } else {
+    auto arr_cpy = contiguous_copy_cpu(arr, s);
+    encoder.add_temporary(arr_cpy);
+    return arr_cpy;
+  }
+};
+
 const static float FP4_LUT[16] = {
     +0.0f,
     +0.5f,
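This file-scope helper consolidates the per-call ensure_row_contiguous lambdas that QuantizedMatmul::eval_cpu and fast::Quantize::eval_cpu defined locally (both removed in the hunks below). Because it registers any copy as an encoder temporary itself, callers no longer have to thread a copied flag through, and the new QQMatmul::eval_cpu can reuse it directly.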
@@ -922,20 +935,9 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& scales_pre = inputs[2];
 
   auto& encoder = cpu::get_command_encoder(stream());
-  auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) {
-    if (arr.flags().row_contiguous) {
-      return arr;
-    } else {
-      auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
-      copy_cpu(arr, arr_cpy, CopyType::General, s);
-      encoder.add_temporary(arr_cpy);
-      return arr_cpy;
-    }
-  };
-
-  auto x = ensure_row_contiguous(x_pre);
-  auto w = ensure_row_contiguous(w_pre);
-  auto scales = ensure_row_contiguous(scales_pre);
+  auto x = ensure_row_contiguous(x_pre, encoder, stream());
+  auto w = ensure_row_contiguous(w_pre, encoder, stream());
+  auto scales = ensure_row_contiguous(scales_pre, encoder, stream());
 
   out.set_data(allocator::malloc(out.nbytes()));
 
@@ -944,7 +946,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(scales);
   encoder.set_output_array(out);
   if (mode_ == QuantizationMode::Affine) {
-    auto biases = ensure_row_contiguous(inputs[3]);
+    auto biases = ensure_row_contiguous(inputs[3], encoder, stream());
     encoder.set_input_array(biases);
     encoder.dispatch([out = array::unsafe_weak_copy(out),
                       x = array::unsafe_weak_copy(x),
@@ -1052,6 +1054,105 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
+uint8_t to_fp8_e8m0(float x) {
+  if (!std::isfinite(x)) {
+    return 0xFF;
+  }
+  if (x < 0.0f) {
+    return 0x00;
+  }
+  float le = std::log2(x);
+  int n = int(std::round(le));
+
+  n = n < -127 ? -127 : n;
+  n = n > 127 ? 127 : n;
+  return static_cast<uint8_t>(n + 127);
+}
+
+uint8_t to_fp4_e2m1(float x) {
+  if (std::isnan(x)) {
+    return 0x7;
+  }
+
+  const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
+  x = std::abs(x);
+
+  uint8_t bits;
+  if (x > 5.0f) {
+    bits = 0x7;
+  } else if (x >= 3.5f) {
+    bits = 0x6;
+  } else if (x > 2.5f) {
+    bits = 0x5;
+  } else if (x >= 1.75f) {
+    bits = 0x4;
+  } else if (x > 1.25f) {
+    bits = 0x3;
+  } else if (x >= 0.75f) {
+    bits = 0x2;
+  } else if (x > 0.25f) {
+    bits = 0x1;
+  } else {
+    bits = 0x0;
+  }
+  return bits | sign_bit;
+}
+
+template <typename T>
+void fp_quantize_dequantize(
+    const array& w_arr,
+    array& out_arr,
+    int bits,
+    int group_size,
+    size_t w_size) {
+  auto w = w_arr.data<T>();
+  auto out = out_arr.data<T>();
+
+  size_t n_groups = w_size / group_size;
+
+  for (size_t i = 0; i < n_groups; ++i) {
+    size_t idx = i * group_size;
+    float scale = -std::numeric_limits<float>::infinity();
+    for (int j = 0; j < group_size; ++j) {
+      scale = std::max(scale, std::abs(w[idx + j]));
+    }
+    scale /= bits == 4 ? 6.0f : 448.0f;
+    if (group_size == 16) {
+      scale = dequantize_scale<float, 16>(detail::ToFP8()(scale));
+    } else {
+      scale = dequantize_scale<float, 32>(to_fp8_e8m0(scale));
+    }
+
+    for (int j = 0; j < group_size; ++j) {
+      float w_el = scale == 0 ? 0.0f : w[idx + j] / scale;
+      float output;
+      if (bits == 8) {
+        output = detail::FromFP8()(detail::ToFP8()(w_el));
+      } else {
+        output = FP4_LUT[to_fp4_e2m1(w_el)];
+      }
+      out[idx + j] = static_cast<T>(scale * output);
+    }
+  }
+}
+
+void dispatch_quantize_dequantize(
+    const array& w,
+    array& out,
+    int bits,
+    int group_size) {
+  if (w.dtype() == float16) {
+    fp_quantize_dequantize<float16_t>(w, out, bits, group_size, w.size());
+  } else if (w.dtype() == bfloat16) {
+    fp_quantize_dequantize<bfloat16_t>(w, out, bits, group_size, w.size());
+  } else if (w.dtype() == float32) {
+    fp_quantize_dequantize<float>(w, out, bits, group_size, w.size());
+  } else {
+    throw std::runtime_error(
+        "[quantize_dequantize] Only supports floating point inputs");
+  }
+}
+
 template <typename T, typename U>
 void quantize(
     const T* w,
@@ -1136,26 +1237,15 @@ void dispatch_quantize(
 void fast::Quantize::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  auto ensure_row_contiguous = [s = stream()](const array& arr) {
-    if (arr.flags().row_contiguous) {
-      return std::make_pair(arr, false);
-    } else {
-      return std::make_pair(contiguous_copy_cpu(arr, s), true);
-    }
-  };
-
-  auto [w, copied] = ensure_row_contiguous(inputs[0]);
+  auto& encoder = cpu::get_command_encoder(stream());
+  auto w = ensure_row_contiguous(inputs[0], encoder, stream());
   auto& out = outputs[0];
   out.set_data(allocator::malloc(out.nbytes()));
 
   auto& scales = outputs[1];
   auto& biases = outputs[2];
   scales.set_data(allocator::malloc(scales.nbytes()));
   biases.set_data(allocator::malloc(biases.nbytes()));
-  auto& encoder = cpu::get_command_encoder(stream());
-  if (copied) {
-    encoder.add_temporary(w);
-  }
   encoder.set_input_array(w);
   encoder.set_input_array(scales);
   encoder.set_input_array(biases);
@@ -1238,6 +1328,43 @@ void fast::ConvertFP8::eval_cpu(
 }
 
 void QQMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  throw std::runtime_error("QQMatmul not implemented on CPU.");
+  auto& encoder = cpu::get_command_encoder(stream());
+
+  bool w_quantized = (inputs[1].dtype() == uint32);
+  if (w_quantized && inputs[0].shape(-2) == 1) {
+    bool donate_x = inputs[0].is_donatable();
+    auto x = ensure_row_contiguous(inputs[0], encoder, stream());
+    auto w = ensure_row_contiguous(inputs[1], encoder, stream());
+    auto scales = ensure_row_contiguous(inputs[2], encoder, stream());
+
+    out.set_data(allocator::malloc(out.nbytes()));
+
+    // If x is a copy it should be donatable
+    donate_x |= x.is_donatable();
+    auto xhat = donate_x
+        ? x
+        : array(allocator::malloc(x.nbytes()), x.shape(), x.dtype());
+    if (!donate_x) {
+      encoder.add_temporary(xhat);
+    }
+    encoder.set_input_array(x);
+    encoder.set_input_array(w);
+    encoder.set_input_array(scales);
+    encoder.set_output_array(out);
+    encoder.dispatch([out = array::unsafe_weak_copy(out),
+                      x = array::unsafe_weak_copy(x),
+                      xhat = array::unsafe_weak_copy(xhat),
+                      w = array::unsafe_weak_copy(w),
+                      scales = array::unsafe_weak_copy(scales),
+                      group_size_ = group_size_,
+                      bits_ = bits_]() mutable {
+      dispatch_quantize_dequantize(x, xhat, bits_, group_size_);
+      fp_qmm_dispatch(out, xhat, w, scales, group_size_, bits_, true);
+    });
+    return;
+  } else {
+    throw std::runtime_error("[QQMatmul] NYI for the general case");
+  }
 }
+
 } // namespace mlx::core
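The new QQMatmul CPU path above handles only the matrix-vector case (a single row in inputs[0] against uint32-packed weights): it fake-quantizes x into xhat with dispatch_quantize_dequantize, so the activations are rounded onto the same fp4/fp8 grid as the weights, then hands off to fp_qmm_dispatch (declared in another of the 22 changed files). The two scalar encoders it adds are easy to sanity-check in isolation: e8m0 is an exponent-only byte (value 2^(code - 127), 0xFF reserved for NaN), so to_fp8_e8m0 rounds log2(x) to the nearest integer and clamps it to [-127, 127]; e2m1 has a sign bit, two exponent bits, and one mantissa bit, so its magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and the threshold chain in to_fp4_e2m1 sits at the midpoints between neighbors, with the mix of > and >= implementing round-half-to-even. A minimal standalone sketch (the main harness is illustrative, not part of the commit):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Mirrors the commit's to_fp8_e8m0: keep only a biased exponent.
uint8_t to_fp8_e8m0(float x) {
  if (!std::isfinite(x)) {
    return 0xFF;
  }
  if (x < 0.0f) {
    return 0x00;
  }
  float le = std::log2(x);
  int n = int(std::round(le));

  n = n < -127 ? -127 : n;
  n = n > 127 ? 127 : n;
  return static_cast<uint8_t>(n + 127);
}

// Positive e2m1 magnitudes, indexed by the low three bits of the fp4 code.
const float FP4_POS[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

int main() {
  printf("e8m0(0.25) = %d\n", to_fp8_e8m0(0.25f)); // 125: 2^(125 - 127) = 0.25
  printf("e8m0(448)  = %d\n", to_fp8_e8m0(448.0f)); // 136: 448 rounds to 2^9
  // Tie-breaking in to_fp4_e2m1: 1.25 is the midpoint of 1.0 and 1.5 and the
  // chain uses '>', so 1.25 stays at 1.0 (even code); 3.5 is the midpoint of
  // 3.0 and 4.0 and the chain uses '>=', so 3.5 rounds up to 4.0 (even code).
  printf("max e2m1   = %g\n", FP4_POS[7]); // 6
  return 0;
}

The 6.0f and 448.0f divisors in fp_quantize_dequantize are exactly these format maxima: the per-group scale is chosen so the group's largest magnitude lands on the largest representable e2m1 or e4m3 value.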

mlx/backend/cuda/CMakeLists.txt

Lines changed: 7 additions & 4 deletions
@@ -56,7 +56,10 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
+    ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmv.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
 
@@ -66,12 +69,12 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
 # fp4 is not available on < 12.8
 if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
   target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/no_qqmm_impl.cpp)
 else()
   target_sources(
-    mlx
-    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp
-            ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp
-            ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_impl.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp)
 endif()
 
 if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)

mlx/backend/cuda/primitives.cpp

Lines changed: 0 additions & 8 deletions
@@ -24,20 +24,12 @@ namespace mlx::core {
     throw std::runtime_error(#func " has no CUDA implementation."); \
   }
 
-#if CUDART_VERSION < 12080
-void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
-  throw std::runtime_error(
-      "[QQMatmul::eval_gpu] QQMM is only supported with CUDA 12.8 or higher.");
-}
-#endif
-
 NO_GPU(BlockMaskedMM)
 NO_GPU(FFT)
 NO_GPU(GatherQMM)
 NO_GPU(Hadamard)
 NO_GPU_MULTI(LUF)
 NO_GPU_MULTI(QRF)
-NO_GPU(QuantizedMatmul)
 NO_GPU(SegmentedMM)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
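Both removals line up with the rest of the commit: QuantizedMatmul gains a CUDA implementation (the new quantized/qmv.cu registered in the CMakeLists change), so it leaves the NO_GPU list, and the CUDA-version stub for QQMatmul presumably moves into quantized/no_qqmm_impl.cpp, which is now compiled on toolkits older than 12.8 in place of this #if guard.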

mlx/backend/cuda/quantized/cuda_fp4.h

Lines changed: 17 additions & 0 deletions
@@ -81,3 +81,20 @@ struct __nv_fp4_e2m1 {
   }
   uint8_t __x{0};
 };
+
+struct __nv_fp4x4_e2m1 {
+  __device__ operator float4() {
+    float4 out;
+    auto bits = __high & 0xf;
+    out.x = float(*(__nv_fp4_e2m1*)(&bits));
+    bits = (__high >> 4) & 0xf;
+    out.y = float(*(__nv_fp4_e2m1*)(&bits));
+    bits = (__low) & 0xf;
+    out.z = float(*(__nv_fp4_e2m1*)(&bits));
+    bits = (__low >> 4) & 0xf;
+    out.w = float(*(__nv_fp4_e2m1*)(&bits));
+    return out;
+  }
+  uint8_t __high{0};
+  uint8_t __low{0};
+};
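The struct added here complements the single-element __nv_fp4_e2m1 wrapper in this compatibility header (which, per the CMakeLists change above, only goes on the include path for CUDA toolkits older than 12.8, where the real fp4 types are unavailable). The unpack order is fixed: x and y come from the low and high nibbles of __high, z and w from the low and high nibbles of __low. A host-side C++ sketch of the same bit layout, using the standard e2m1 value table (the harness is hypothetical, not part of the header):

#include <cstdint>
#include <cstdio>

// Decode one e2m1 nibble: the top bit is the sign, the low three bits index
// the magnitude table {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
float fp4_to_float(uint8_t code) {
  static const float lut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  float mag = lut[code & 0x7];
  return (code & 0x8) ? -mag : mag;
}

int main() {
  // Same nibble order as __nv_fp4x4_e2m1::operator float4().
  uint8_t high = 0x2b; // x = low nibble 0xb -> -1.5, y = high nibble 0x2 -> 1.0
  uint8_t low = 0x70;  // z = low nibble 0x0 -> 0.0, w = high nibble 0x7 -> 6.0
  printf(
      "x=%g y=%g z=%g w=%g\n",
      fp4_to_float(high & 0xf),
      fp4_to_float((high >> 4) & 0xf),
      fp4_to_float(low & 0xf),
      fp4_to_float((low >> 4) & 0xf));
  return 0;
}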
