@@ -56,21 +56,21 @@ def _kernel_quantize_mx4_unpack(
        USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
    """
    # Define Constant Expressions.
-   FP32_EXP_MASK: tl.constexpr = 0x7F800000  # type: ignore[Incompatible variable type]
-   FP32_EXP_OFFSET: tl.constexpr = 23  # type: ignore[Incompatible variable type]
-   FP32_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
-   FP32_SIGN_OFFSET: tl.constexpr = 31  # type: ignore[Incompatible variable type]
+   FP16_EXP_MASK: tl.constexpr = 0x7F80  # type: ignore[Incompatible variable type]
+   FP16_EXP_OFFSET: tl.constexpr = 7  # type: ignore[Incompatible variable type]
+   FP16_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
+   FP16_SIGN_OFFSET: tl.constexpr = 15  # type: ignore[Incompatible variable type]
    SIGN_MASK: tl.constexpr = 0x1  # type: ignore[Incompatible variable type]
-   FP32_MANTISSA_MASK: tl.constexpr = 0x007FFFFF  # type: ignore[Incompatible variable type]
+   FP16_MANTISSA_MASK: tl.constexpr = 0x007F  # type: ignore[Incompatible variable type]
    # FP4 has 2 mantissa bits, one explicit one implicit.
    MBITS_IMPLICIT: tl.constexpr = MBITS + 1  # type: ignore[Incompatible variable type]
-   MAX_FP32_MANTISSA_BITS: tl.constexpr = 24  # type: ignore[Incompatible variable type]
-   IMPLIED_1_BIT: tl.constexpr = 1 << 23  # type: ignore[Incompatible variable type]
-   FP32_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+   MAX_FP16_MANTISSA_BITS: tl.constexpr = 8  # type: ignore[Incompatible variable type]
+   IMPLIED_1_BIT: tl.constexpr = 1 << 7  # type: ignore[Incompatible variable type]
+   FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
    MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1  # type: ignore[Incompatible variable type]
    EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1  # type: ignore[Incompatible variable type]
    IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1
-   RAND_MASK: tl.constexpr = (1 << (FP32_EXP_OFFSET - MBITS)) - 1  # type: ignore[Incompatible variable type]
+   RAND_MASK: tl.constexpr = (1 << (FP16_EXP_OFFSET - MBITS)) - 1  # type: ignore[Incompatible variable type]

    # Get the current thread number.
    pid = tl.program_id(0)
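The new FP16_-prefixed constants describe a 1-bit sign / 8-bit exponent / 7-bit mantissa layout with bias 127, i.e. the bfloat16 layout that the scaled values are cast to later in this diff. As a hedged sanity check (plain PyTorch, not part of the kernel), here is how those constants pick apart one bfloat16 bit pattern:

```python
import torch

# Sketch only: mirrors the constants above for a single bfloat16 scalar.
FP16_EXP_MASK, FP16_EXP_OFFSET, FP16_EXP_BIAS = 0x7F80, 7, 127
FP16_SIGN_OFFSET, FP16_MANTISSA_MASK, SIGN_MASK = 15, 0x007F, 0x1

x = torch.tensor([-1.5], dtype=torch.bfloat16)
bits = x.view(torch.int16).item() & 0xFFFF  # reinterpret the 16-bit pattern

sign = (bits >> FP16_SIGN_OFFSET) & SIGN_MASK           # 1 -> negative
biased_exp = (bits & FP16_EXP_MASK) >> FP16_EXP_OFFSET  # 127, i.e. unbiased exponent 0
mantissa = bits & FP16_MANTISSA_MASK                    # 0x40, the trailing ".5"
assert (sign, biased_exp - FP16_EXP_BIAS, mantissa) == (1, 0, 0x40)
```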
@@ -137,7 +137,7 @@ def _kernel_quantize_mx4_unpack(
    # Compute the shared exponent of each group.
    group_max = tl.max(tl.abs(a_groups), axis=1)
    # Prevent infinite values in log.
-   group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
+   group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
    # Load relevant random values if doing stochastic rounding
    # or stochastic casting.
    group_rand_bits = None
@@ -156,19 +156,20 @@ def _kernel_quantize_mx4_unpack(
    group_exp = tl.clamp(group_exp, -127, 125)

    # Next we scale A in preparation for quantization.
-   scale_ = tl.exp2(group_exp.to(tl.float64)).to(tl.float32)
+   scale_ = tl.exp2(group_exp.to(tl.float64)).to(tl.bfloat16)
    # Apply scale_ to input. We do this by broadcasting scale.
-   scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) / tl.reshape(
-       scale_, [GROUP_LOAD, 1]
-   )
+   scaled_a = (
+       tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
+       / tl.reshape(scale_, [GROUP_LOAD, 1])
+   ).to(tl.bfloat16)
    # Reshape back to a flat array.
    scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])

    # We're done with group_exp now so we can write it out.
-   # We readd fp32_exp_bias for compatibility with cuda dequant.
+   # We readd fp16_exp_bias for compatibility with cuda dequant.
    tl.store(
        scale + exp_offset,
-       (group_exp + FP32_EXP_BIAS).to(tl.int8),
+       (group_exp + FP16_EXP_BIAS).to(tl.int8),
        # Prevent writing outside this chunk or the main array.
        mask=(exp_offset < SCALE_SIZE)
        & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1))),
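Outside Triton, the scaling and scale-storage step in this hunk corresponds roughly to the sketch below (PyTorch; the group count and GROUP_SIZE are made up for illustration, and the derivation of group_exp from group_max happens earlier in the kernel, so it is only stubbed here):

```python
import torch

GROUP_SIZE = 32  # assumed group size for the sketch
a = torch.randn(4, GROUP_SIZE)

# Shared exponent per group (stub for the kernel's actual derivation).
group_max = a.abs().amax(dim=1)
group_max = torch.where(group_max == 0, torch.full_like(group_max, 2.0 ** -126), group_max)
group_exp = torch.clamp(torch.floor(torch.log2(group_max)), -127, 125)

# The scale is now materialized in bfloat16 instead of fp32, and the scaled
# values are cast to bfloat16 before the bit-level quantization below.
scale_ = torch.exp2(group_exp.to(torch.float64)).to(torch.bfloat16)
scaled_a = (a / scale_.unsqueeze(1)).to(torch.bfloat16)

# The stored per-group scale is the re-biased exponent, written as int8.
stored_scale = (group_exp + 127).to(torch.int8)
```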
@@ -179,7 +180,7 @@ def _kernel_quantize_mx4_unpack(

    # During quantization, we're going to be doing a lot of bitwise operations.
    # This is easier to work with in int32.
-   scaled_a = scaled_a.to(tl.int32, bitcast=True)
+   scaled_a = scaled_a.to(tl.int16, bitcast=True)

    # When doing stochastic downcasting, generate random values for this block
    # and apply it to the mantissa.
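With MBITS = 1 for FP4, RAND_MASK becomes (1 << 6) - 1, which covers exactly the six mantissa bits a normal bfloat16 value loses on the way down to FP4's two significant bits. A small, hedged illustration of why adding masked random bits before truncation gives stochastic rounding (simplified to a bare mantissa, ignoring the carry into the exponent that the kernel's full-width addition handles):

```python
import random

RAND_MASK = (1 << 6) - 1  # the low 6 bits are dropped for normal values

def stochastic_truncate(mantissa_bits: int) -> int:
    """Drop 6 bits, rounding up with probability equal to the dropped fraction."""
    return (mantissa_bits + (random.getrandbits(16) & RAND_MASK)) >> 6

# 0b1010000 has a dropped fraction of 16/64 = 0.25, so the result is 2
# about 25% of the time and 1 otherwise; the expected value stays ~1.25.
samples = [stochastic_truncate(0b1010000) for _ in range(100_000)]
print(sum(samples) / len(samples))  # ≈ 1.25
```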
@@ -212,28 +213,28 @@ def _kernel_quantize_mx4_unpack(
        # Flatten back to simple array.
        stochastic_round_bits = tl.reshape(
            stochastic_round_bits, [GROUP_LOAD * GROUP_SIZE]
-       ).to(tl.int32, bitcast=True)
+       ).to(tl.int16, bitcast=True)

        # Mask off mantissa bits of random value and add to mantissa.
        scaled_a = scaled_a + (stochastic_round_bits & RAND_MASK)

    # Extract sign bit of value.
-   sign_bit = (scaled_a >> FP32_SIGN_OFFSET) & SIGN_MASK
+   sign_bit = (scaled_a >> FP16_SIGN_OFFSET) & SIGN_MASK

    # Extract exponent.
-   biased_exp = (scaled_a & FP32_EXP_MASK) >> FP32_EXP_OFFSET
+   biased_exp = (scaled_a & FP16_EXP_MASK) >> FP16_EXP_OFFSET

    # Extract mantissa.
-   trailing_mantissa = scaled_a & FP32_MANTISSA_MASK
+   trailing_mantissa = scaled_a & FP16_MANTISSA_MASK

    # Adjust exponent bias for FP4.
-   new_biased_exp = biased_exp - FP32_EXP_BIAS + FP4_EXP_BIAS
+   new_biased_exp = biased_exp - FP16_EXP_BIAS + FP4_EXP_BIAS

    # Compute difference between ideal exponent and what fp4 can represent.
    exp_diff = tl.where(new_biased_exp <= 0, 1 - new_biased_exp, 0)

    # Clip this difference to maximum number of fp32 mantissa bits.
-   exp_diff = tl.minimum(exp_diff, MAX_FP32_MANTISSA_BITS)
+   exp_diff = tl.minimum(exp_diff, MAX_FP16_MANTISSA_BITS)

    # Now we round our fp32 mantissa down to fp4.
    is_subnorm = biased_exp == 0
@@ -243,9 +244,9 @@ def _kernel_quantize_mx4_unpack(
    )
    # Compute base number of bits corresponding to the mantissa, smaller for subnorms
    # since implied one is included in exp_diff.
-   fp32_sig_bits = tl.where(is_subnorm, 23, 24).to(tl.int32)
+   fp16_sig_bits = tl.where(is_subnorm, 7, 8).to(tl.int32)
    # Now we're ready to shift down to target bitwidth (with an extra bit for rounding).
-   mantissa = mantissa >> (fp32_sig_bits + exp_diff - MBITS_IMPLICIT - 1)
+   mantissa = mantissa >> (fp16_sig_bits + exp_diff - MBITS_IMPLICIT - 1)
    # Perform rounding by adding 1 and shifting down.
    mantissa = (mantissa + 1) >> 1
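Putting the bit manipulation above together, a hedged scalar walk-through of the rounding path in plain Python (FP4_EXP_BIAS is assumed to be the standard E2M1 bias of 1; the mantissa/exponent overflow handling that follows this hunk in the kernel is omitted):

```python
def bf16_bits_to_fp4_fields(bits: int) -> tuple[int, int, int]:
    """Scalar sketch of the hunks above: bfloat16 bit pattern -> (sign, biased fp4 exp, mantissa)."""
    MBITS_IMPLICIT = 2   # 1 explicit + 1 implicit mantissa bit
    FP4_EXP_BIAS = 1     # assumed standard E2M1 bias
    IMPLIED_1_BIT = 1 << 7

    sign_bit = (bits >> 15) & 0x1
    biased_exp = (bits & 0x7F80) >> 7
    trailing_mantissa = bits & 0x007F

    new_biased_exp = biased_exp - 127 + FP4_EXP_BIAS
    exp_diff = min(1 - new_biased_exp, 8) if new_biased_exp <= 0 else 0

    is_subnorm = biased_exp == 0
    mantissa = trailing_mantissa if is_subnorm else trailing_mantissa + IMPLIED_1_BIT
    sig_bits = 7 if is_subnorm else 8
    mantissa >>= sig_bits + exp_diff - MBITS_IMPLICIT - 1
    mantissa = (mantissa + 1) >> 1  # round to nearest by add-then-shift
    return sign_bit, new_biased_exp, mantissa

# 1.5 in bfloat16 is 0x3FC0 -> sign 0, fp4 exponent 1 (unbiased 0), mantissa 0b11, i.e. 1.5.
assert bf16_bits_to_fp4_fields(0x3FC0) == (0, 1, 0b11)
```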