Skip to content

Commit

Permalink
Use exp2 for mx scaling
Browse files Browse the repository at this point in the history
stack-info: PR: #1530, branch: drisspg/stack/26
  • Loading branch information
drisspg committed Jan 15, 2025
1 parent 11333ba commit 9d068d5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 20 deletions.
13 changes: 3 additions & 10 deletions torchao/prototype/mx_formats/custom_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,13 @@
_f32_to_floatx_unpacked,
_floatx_unpacked_to_f32,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

# TODO(future): if needed, make the below work on previous PyTorch versions,
# just need to hunt down the previous location of `libdevice`. An assert
# at the callsite prevents usage of this on unsupported versions.
if TORCH_VERSION_AT_LEAST_2_4 and has_triton():
from torch._inductor.runtime.triton_helpers import libdevice

from torchao.prototype.mx_formats.constants import (
E8M0_EXPONENT_BIAS,
E8M0_EXPONENT_NAN_VAL,
F4_E2M1_EXP_BIAS,
F32_EXP_BIAS,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4


def get_bits(x: torch.Tensor) -> str:
Expand Down Expand Up @@ -294,8 +287,8 @@ def triton_f4_to_scaled_bf16_kernel(
s = tl.load(s_ptr + offsets_s, mask=mask_s)

# create the scale in bf16
s_offset = s.to(tl.int16) - e8m0_exponent_bias
s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
# S is already biased by 127, so we just have to shift it to align w/ bf16
s_fp = (s.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True)
s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))

# multiply output by scale
Expand Down
13 changes: 3 additions & 10 deletions torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,7 @@ def to_mx(

# For now, calculate the scale in floating point.
# TODO(future) audit if there is a need to bit shift exponents instead.
scale_fp = torch.pow(
torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
scale_e8m0_unbiased,
)
scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)

# Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
# float32 denormal range. For now, manually adjust the fp scale. This is
Expand Down Expand Up @@ -176,14 +173,10 @@ def to_mx(


def get_fp_scale(scale_e8m0):
s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
# TODO(later): it would be nice if there was a way to do the 2^x operation
# in PyTorch without creating a tensor of twos
two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
# pow(two, s_offset) can be out of range of floating point formats.
# TODO(later): handle this for float16 if we decide to support float16
# scales.
s_fp = torch.pow(two, s_offset)
s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
s_fp = torch.exp2(s_offset)

# If a block exponent was 255, set values of that block to NaN
s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
Expand Down

0 comments on commit 9d068d5

Please sign in to comment.