Skip to content

Commit 39b04ce

Browse files
author
Awni Hannun
authored
use faster dequant for fp4 qmv (ml-explore#2720)
1 parent d9e6349 commit 39b04ce

File tree

3 files changed

+76
-144
lines changed

3 files changed

+76
-144
lines changed

mlx/backend/metal/kernels/fp4.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,10 @@ struct fp4_e2m1 {
4949
}
5050

5151
operator float() {
52-
return FP4_LUT[bits];
52+
half converted = as_type<half>(ushort((bits & 7) << 9));
53+
converted *= 16384.0;
54+
converted = bits & 8 ? -converted : converted;
55+
return converted;
5356
}
5457

5558
uint8_t bits;

mlx/backend/metal/kernels/fp8.h

Lines changed: 9 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,29 +1,22 @@
11
#pragma once
22

3-
inline float fp32_from_bits(uint32_t bits) {
4-
return *(reinterpret_cast<thread float*>(&bits));
5-
}
6-
inline float fp32_to_bits(float x) {
7-
return *(reinterpret_cast<thread uint32_t*>(&x));
8-
}
9-
103
struct fp8_e4m3 {
114
template <typename T>
125
fp8_e4m3(T f) {
136
// From PyTorch
147
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
158
uint32_t fp8_max = 543 << 21;
169
uint32_t denorm_mask = 141 << 23;
17-
uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
10+
uint32_t f_bits = as_type<uint32_t>(static_cast<float>(f));
1811
uint32_t sign = f_bits & 0x80000000;
1912
f_bits ^= sign;
2013
if (f_bits >= fp8_max) {
2114
// Default behavior saturates to min/max
2215
bits = 0x7E;
2316
} else {
2417
if (f_bits < (121 << 23)) {
25-
f_bits =
26-
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
18+
f_bits = as_type<uint32_t>(
19+
as_type<float>(f_bits) + as_type<float>(denorm_mask));
2720
bits = static_cast<uint8_t>(f_bits - denorm_mask);
2821
} else {
2922
// resulting mantissa is odd
@@ -53,7 +46,7 @@ struct fp8_e4m3 {
5346
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
5447
inf_nan_mask) &
5548
~zero_mask);
56-
return fp32_from_bits(result);
49+
return as_type<float>(result);
5750
}
5851

5952
uint8_t bits;
@@ -77,11 +70,12 @@ struct fp8_e8m0 {
7770
bits = static_cast<uint8_t>(n + 127);
7871
}
7972

73+
operator bfloat16_t() {
74+
uint16_t out = (bits == 0 ? 0x40 : (static_cast<uint16_t>(bits) << 7));
75+
return as_type<bfloat16_t>(out);
76+
}
8077
operator float() {
81-
if (bits == 0xFF) {
82-
return metal::numeric_limits<float>::quiet_NaN();
83-
}
84-
return metal::ldexp(1.0f, static_cast<int>(bits) - 127);
78+
return static_cast<float>(this->operator bfloat16_t());
8579
}
8680

8781
uint8_t bits;

0 commit comments

Comments
 (0)