
Commit f52ab82

Tianyu Liang authored and facebook-github-bot committed
Support BF16 in Triton downcast quantization mx4 unpack kernel (pytorch#4203)
Summary:
Pull Request resolved: pytorch#4203
X-link: facebookresearch/FBGEMM#1279

As title

Reviewed By: jiawenliu64

Differential Revision: D75092541

fbshipit-source-id: 07c73809d3ca096eb2ecac999baec836dd8db388

1 parent e653778 commit f52ab82

File tree

fbgemm_gpu/experimental/gemm/test/fp4_quantize_test.py
fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py
fbgemm_gpu/test/quantize/mx4_test.py

3 files changed: +26 -27 lines changed

fbgemm_gpu/experimental/gemm/test/fp4_quantize_test.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ def _test_quantize_fp4(
         rounding_mode = RoundingMode.even
         packed_group_size = group_size // 2
         groups_per_row = math.ceil(N / group_size)
-
         x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
         xq_ref, x_scale_ref = triton_quantize_mx4_unpack(
             x, group_size=group_size, rounding_mode=rounding_mode
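
For context, a hypothetical end-to-end call matching the test above. The function name, its group_size/rounding_mode arguments, and the BF16 input come from this diff; the exact import paths and shapes are assumptions for illustration only.

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    triton_quantize_mx4_unpack,  # module path inferred from the file path above
)
from fbgemm_gpu.triton.quantize import RoundingMode  # import location is an assumption

# Quantize a BF16 tensor to packed MX4 values plus per-group scales.
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
xq, x_scale = triton_quantize_mx4_unpack(
    x, group_size=32, rounding_mode=RoundingMode.even
)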

fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py

Lines changed: 26 additions & 25 deletions
@@ -56,21 +56,21 @@ def _kernel_quantize_mx4_unpack(
         USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
     """
     # Define Constant Expressions.
-    FP32_EXP_MASK: tl.constexpr = 0x7F800000  # type: ignore[Incompatible variable type]
-    FP32_EXP_OFFSET: tl.constexpr = 23  # type: ignore[Incompatible variable type]
-    FP32_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
-    FP32_SIGN_OFFSET: tl.constexpr = 31  # type: ignore[Incompatible variable type]
+    FP16_EXP_MASK: tl.constexpr = 0x7F80  # type: ignore[Incompatible variable type]
+    FP16_EXP_OFFSET: tl.constexpr = 7  # type: ignore[Incompatible variable type]
+    FP16_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
+    FP16_SIGN_OFFSET: tl.constexpr = 15  # type: ignore[Incompatible variable type]
     SIGN_MASK: tl.constexpr = 0x1  # type: ignore[Incompatible variable type]
-    FP32_MANTISSA_MASK: tl.constexpr = 0x007FFFFF  # type: ignore[Incompatible variable type]
+    FP16_MANTISSA_MASK: tl.constexpr = 0x007F  # type: ignore[Incompatible variable type]
     # FP4 has 2 mantissa bits, one explicit one implicit.
     MBITS_IMPLICIT: tl.constexpr = MBITS + 1  # type: ignore[Incompatible variable type]
-    MAX_FP32_MANTISSA_BITS: tl.constexpr = 24  # type: ignore[Incompatible variable type]
-    IMPLIED_1_BIT: tl.constexpr = 1 << 23  # type: ignore[Incompatible variable type]
-    FP32_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    MAX_FP16_MANTISSA_BITS: tl.constexpr = 8  # type: ignore[Incompatible variable type]
+    IMPLIED_1_BIT: tl.constexpr = 1 << 7  # type: ignore[Incompatible variable type]
+    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
     MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1  # type: ignore[Incompatible variable type]
     EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1  # type: ignore[Incompatible variable type]
     IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1
-    RAND_MASK: tl.constexpr = (1 << (FP32_EXP_OFFSET - MBITS)) - 1  # type: ignore[Incompatible variable type]
+    RAND_MASK: tl.constexpr = (1 << (FP16_EXP_OFFSET - MBITS)) - 1  # type: ignore[Incompatible variable type]

     # Get the current thread number.
     pid = tl.program_id(0)
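
The new constants describe the bfloat16 bit layout (1 sign bit, 8 exponent bits, 7 mantissa bits) rather than fp32's. A minimal sketch outside the kernel, assuming only PyTorch, that checks these field positions by bitcasting a bfloat16 value to int16; the example value -1.5 is illustrative only.

import torch

# Sketch only: verify the bf16 field constants used above (sign at bit 15,
# exponent mask 0x7F80 with offset 7 and bias 127, mantissa mask 0x007F).
x = torch.tensor([-1.5], dtype=torch.bfloat16)
bits = x.view(torch.int16).item() & 0xFFFF  # reinterpret the 16 bits as an unsigned int

sign_bit = (bits >> 15) & 0x1        # FP16_SIGN_OFFSET, SIGN_MASK
biased_exp = (bits & 0x7F80) >> 7    # FP16_EXP_MASK, FP16_EXP_OFFSET
trailing_mantissa = bits & 0x007F    # FP16_MANTISSA_MASK

# For -1.5: sign_bit == 1, biased_exp == 127 (unbiased 0), trailing_mantissa == 0b1000000.
print(sign_bit, biased_exp - 127, trailing_mantissa)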
@@ -137,7 +137,7 @@ def _kernel_quantize_mx4_unpack(
     # Compute the shared exponent of each group.
     group_max = tl.max(tl.abs(a_groups), axis=1)
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
     # Load relevant random values if doing stochastic rounding
     # or stochastic casting.
     group_rand_bits = None
@@ -156,19 +156,20 @@ def _kernel_quantize_mx4_unpack(
     group_exp = tl.clamp(group_exp, -127, 125)

     # Next we scale A in preparation for quantization.
-    scale_ = tl.exp2(group_exp.to(tl.float64)).to(tl.float32)
+    scale_ = tl.exp2(group_exp.to(tl.float64)).to(tl.bfloat16)
     # Apply scale_ to input. We do this by broadcasting scale.
-    scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) / tl.reshape(
-        scale_, [GROUP_LOAD, 1]
-    )
+    scaled_a = (
+        tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
+        / tl.reshape(scale_, [GROUP_LOAD, 1])
+    ).to(tl.bfloat16)
     # Reshape back to a flat array.
     scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])

     # We're done with group_exp now so we can write it out.
-    # We readd fp32_exp_bias for compatibility with cuda dequant.
+    # We readd fp16_exp_bias for compatibility with cuda dequant.
     tl.store(
         scale + exp_offset,
-        (group_exp + FP32_EXP_BIAS).to(tl.int8),
+        (group_exp + FP16_EXP_BIAS).to(tl.int8),
         # Prevent writing outside this chunk or the main array.
         mask=(exp_offset < SCALE_SIZE)
         & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1))),
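
With BF16 inputs, the power-of-two scale is now materialized in bfloat16 before the division. A rough stand-alone sketch of this hunk's scale handling in plain PyTorch; the shared-exponent value itself is computed earlier in the kernel and is only a placeholder here.

import torch

# Sketch only: per-group power-of-two scaling and int8 scale storage, mirroring the hunk above.
group = torch.randn(32, dtype=torch.bfloat16)
group_exp = torch.tensor(-3.0)                               # placeholder for the shared exponent
group_exp = torch.clamp(group_exp, -127, 125)
scale_ = torch.exp2(group_exp.double()).to(torch.bfloat16)   # scale kept in bf16, as in the new code
scaled = (group / scale_).to(torch.bfloat16)
stored_scale = (group_exp + 127).to(torch.int8)              # re-add the exp bias for dequant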
@@ -179,7 +180,7 @@ def _kernel_quantize_mx4_unpack(

     # During quantization, we're going to be doing a lot of bitwise operations.
     # This is easier to work with in int32.
-    scaled_a = scaled_a.to(tl.int32, bitcast=True)
+    scaled_a = scaled_a.to(tl.int16, bitcast=True)

     # When doing stochastic downcasting, generate random values for this block
     # and apply it to the mantissa.
@@ -212,28 +213,28 @@ def _kernel_quantize_mx4_unpack(
     # Flatten back to simple array.
     stochastic_round_bits = tl.reshape(
         stochastic_round_bits, [GROUP_LOAD * GROUP_SIZE]
-    ).to(tl.int32, bitcast=True)
+    ).to(tl.int16, bitcast=True)

     # Mask off mantissa bits of random value and add to mantissa.
     scaled_a = scaled_a + (stochastic_round_bits & RAND_MASK)

     # Extract sign bit of value.
-    sign_bit = (scaled_a >> FP32_SIGN_OFFSET) & SIGN_MASK
+    sign_bit = (scaled_a >> FP16_SIGN_OFFSET) & SIGN_MASK

     # Extract exponent.
-    biased_exp = (scaled_a & FP32_EXP_MASK) >> FP32_EXP_OFFSET
+    biased_exp = (scaled_a & FP16_EXP_MASK) >> FP16_EXP_OFFSET

     # Extract mantissa.
-    trailing_mantissa = scaled_a & FP32_MANTISSA_MASK
+    trailing_mantissa = scaled_a & FP16_MANTISSA_MASK

     # Adjust exponent bias for FP4.
-    new_biased_exp = biased_exp - FP32_EXP_BIAS + FP4_EXP_BIAS
+    new_biased_exp = biased_exp - FP16_EXP_BIAS + FP4_EXP_BIAS

     # Compute difference between ideal exponent and what fp4 can represent.
     exp_diff = tl.where(new_biased_exp <= 0, 1 - new_biased_exp, 0)

     # Clip this difference to maximum number of fp32 mantissa bits.
-    exp_diff = tl.minimum(exp_diff, MAX_FP32_MANTISSA_BITS)
+    exp_diff = tl.minimum(exp_diff, MAX_FP16_MANTISSA_BITS)

     # Now we round our fp32 mantissa down to fp4.
     is_subnorm = biased_exp == 0
@@ -243,9 +244,9 @@ def _kernel_quantize_mx4_unpack(
     )
     # Compute base number of bits corresponding to the mantissa, smaller for subnorms
     # since implied one is included in exp_diff.
-    fp32_sig_bits = tl.where(is_subnorm, 23, 24).to(tl.int32)
+    fp16_sig_bits = tl.where(is_subnorm, 7, 8).to(tl.int32)
     # Now we're ready to shift down to target bitwidth (with an extra bit for rounding).
-    mantissa = mantissa >> (fp32_sig_bits + exp_diff - MBITS_IMPLICIT - 1)
+    mantissa = mantissa >> (fp16_sig_bits + exp_diff - MBITS_IMPLICIT - 1)
     # Perform rounding by adding 1 and shifting down.
     mantissa = (mantissa + 1) >> 1
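
A small worked sketch of the shift-and-round path above for one already-scaled value, in plain Python. The fp4_e2m1 parameters (MBITS = 1, EBITS = 2) come from the tests below; FP4_EXP_BIAS = 1 and the implied-one handling are assumptions, since those lines are not part of this diff.

# Sketch only: round the bf16 value 1.5 (bit pattern 0x3FC0) to fp4_e2m1 fields.
bits = 0x3FC0
sign_bit = (bits >> 15) & 0x1           # 0
biased_exp = (bits & 0x7F80) >> 7       # 127
trailing_mantissa = bits & 0x007F       # 0b1000000 == 64

MBITS_IMPLICIT = 1 + 1                  # MBITS + 1
FP4_EXP_BIAS = 1                        # assumed e2m1 bias
new_biased_exp = biased_exp - 127 + FP4_EXP_BIAS            # 1, representable in fp4
exp_diff = max(1 - new_biased_exp, 0)                       # 0, no subnorm shift needed
mantissa = trailing_mantissa + (1 << 7)                     # add the implied one for normal values: 192
fp16_sig_bits = 8                                           # normal value
mantissa = mantissa >> (fp16_sig_bits + exp_diff - MBITS_IMPLICIT - 1)  # 192 >> 5 == 6
mantissa = (mantissa + 1) >> 1                              # 3 == 0b11, i.e. 1.5 in e2m1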

fbgemm_gpu/test/quantize/mx4_test.py

Lines changed: 0 additions & 1 deletion
@@ -165,7 +165,6 @@ def test_mx4(self, power: int, sizes: int) -> None:
         element_format_str = "fp4_e2m1"
         ebits, mbits, emax, max_norm, _ = _get_format_params(element_format_str)
         scale_bits = 8
-
         # Reference from mx_github
         output_ref = fake_quantize_mx(
             input,

0 commit comments
