@@ -127,10 +127,7 @@ def to_mx(
 
     # For now, calculate the scale in floating point.
     # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.pow(
-        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
-        scale_e8m0_unbiased,
-    )
+    scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
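A minimal sketch, not part of the patch, illustrating the two facts this hunk relies on: torch.exp2(x) matches the previous torch.pow(2.0, x) construction without materializing a tensor of twos, and 2**-127 sits below the smallest normal float32 value, i.e. in the denormal range the comment above refers to. The tensor values are illustrative only.

import torch

# exp2 vs. the previous pow-with-a-tensor-of-twos construction
x = torch.tensor([-2.0, 0.0, 5.0])             # illustrative unbiased exponents
old = torch.pow(torch.full(x.size(), 2.0), x)  # old path: 2.0 ** x elementwise
new = torch.exp2(x).to(torch.float32)          # new path
assert torch.equal(old, new)

# 2**-127 is a float32 denormal (below the smallest normal value)
smallest_normal = torch.finfo(torch.float32).tiny          # ~1.1755e-38
print(torch.exp2(torch.tensor(-127.0)).item())             # ~5.8775e-39
print(torch.exp2(torch.tensor(-127.0)).item() < smallest_normal)  # True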
@@ -176,14 +173,10 @@ def to_mx(
 
 
 def get_fp_scale(scale_e8m0):
-    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
-    # TODO(later): it would be nice if there was a way to do the 2^x operation
-    # in PyTorch without creating a tensor of twos
-    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
-    # pow(two, s_offset) can be out of range of floating point formats.
     # TODO(later): handle this for float16 if we decide to support float16
     # scales.
-    s_fp = torch.pow(two, s_offset)
+    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
+    s_fp = torch.exp2(s_offset)
 
     # If a block exponent was 255, set values of that block to NaN
     s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
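A minimal standalone sketch of the exp2-based decode path above, assuming the constants E8M0_EXPONENT_BIAS = 127 and E8M0_EXPONENT_NAN_VAL = 255 defined elsewhere in the file (the e8m0 bias and the all-ones NaN encoding); the input values are illustrative only.

import torch

E8M0_EXPONENT_BIAS = 127     # assumed value, matching the e8m0 exponent bias
E8M0_EXPONENT_NAN_VAL = 255  # assumed value: the all-ones exponent encodes NaN

scale_e8m0 = torch.tensor([0, 127, 254, 255], dtype=torch.uint8)  # illustrative
s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
s_fp = torch.exp2(s_offset)  # integer input is promoted to float32
s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
print(s_fp)                  # [2**-127, 1.0, 2**127, nan]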