
Commit 5bb2abf

Use exp2 for mx scaling
stack-info: PR: #1530, branch: drisspg/stack/26
1 parent 4996101 commit 5bb2abf

File tree

1 file changed: +3 -10 lines changed


torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 3 additions & 10 deletions
@@ -127,10 +127,7 @@ def to_mx(
 
     # For now, calculate the scale in floating point.
     # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.pow(
-        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
-        scale_e8m0_unbiased,
-    )
+    scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
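For readers skimming this first hunk: the old code materialized a tensor filled with 2.0 just to raise it elementwise, while the new code computes the same power of two with a single exp2 call. A minimal standalone sketch of the equivalence follows; the input tensor is a stand-in, not the real scale_e8m0_unbiased from to_mx.

# Sketch only, not part of the commit: the input below is a stand-in for
# the unbiased e8m0 exponents computed earlier in to_mx.
import torch

scale_e8m0_unbiased = torch.tensor([-2.0, 0.0, 3.0], dtype=torch.float32)

# Old pattern: build a tensor of 2.0s, then raise it elementwise.
old_scale = torch.pow(
    torch.full(scale_e8m0_unbiased.size(), 2.0),
    scale_e8m0_unbiased,
)

# New pattern: a single exp2, then cast to float32.
new_scale = torch.exp2(scale_e8m0_unbiased).to(torch.float32)

assert torch.allclose(old_scale, new_scale)  # both are [0.25, 1.0, 8.0]

The second hunk applies the same simplification to get_fp_scale.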
@@ -176,14 +173,10 @@ def to_mx(
 
 
 def get_fp_scale(scale_e8m0):
-    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
-    # TODO(later): it would be nice if there was a way to do the 2^x operation
-    # in PyTorch without creating a tensor of twos
-    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
-    # pow(two, s_offset) can be out of range of floating point formats.
     # TODO(later): handle this for float16 if we decide to support float16
     # scales.
-    s_fp = torch.pow(two, s_offset)
+    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
+    s_fp = torch.exp2(s_offset)
 
     # If a block exponent was 255, set values of that block to NaN
     s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
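Taken together, the rewritten get_fp_scale reduces to a bias subtraction, an exp2, and the NaN masking. A self-contained sketch, assuming the E8M0 bias of 127 and NaN encoding of 255 that the surrounding comments imply (the real constants live elsewhere in mx_formats, and the function name here is hypothetical):

# Standalone sketch of the new shape of get_fp_scale; constants are
# inlined here as assumptions, not imported from the module.
import torch

E8M0_EXPONENT_BIAS = 127      # assumed, matching the "2**-127" comment
E8M0_EXPONENT_NAN_VAL = 255   # assumed, matching the "exponent was 255" comment

def get_fp_scale_sketch(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # Biased e8m0 exponents (uint8) -> signed integer offsets.
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    # exp2 promotes the integer offsets to floating point directly,
    # so no intermediate tensor of twos is needed.
    s_fp = torch.exp2(s_offset)
    # A block exponent of 255 encodes NaN for the whole block.
    return torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))

scales = get_fp_scale_sketch(torch.tensor([126, 127, 130, 255], dtype=torch.uint8))
# -> tensor([0.5, 1.0, 8.0, nan])

As in the to_mx hunk, the change keeps the arithmetic identical while dropping the torch.full allocation and the pow over a tensor of twos.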
