Description
The reproducer:

```python
import torch
import torchao
from torchao.dtypes import Float8Layout, to_affine_quantized_floatx

x = torch.tensor([[0, 0, 0.1, 0.1]], dtype=torch.float16)
x_aqt = to_affine_quantized_floatx(
    x,
    target_dtype=torch.float8_e5m2,
    block_size=[1, x.shape[1]],
    _layout=Float8Layout(mm_config=None),
)
xq, scale, _ = x_aqt.tensor_impl.get_plain()
print("x =", x)
print("xq =", xq)
print("scale =", scale)
```
The output is:

```
x = tensor([[0.0000, 0.0000, 0.1000, 0.1000]], dtype=torch.float16)
xq = tensor([[ nan, nan, 57344., 57344.]], dtype=torch.float8_e5m2)
scale = tensor([1.7285e-06], dtype=torch.float16)
```
Basically, the problem is that the quantization code maps the [0, 0.1] range to [0, 57344] (57344 being the maximum value of the `torch.float8_e5m2` data type), so the scale gets very small. Its reciprocal then overflows to `Inf`, and `0 * Inf` produces `NaN`s as quantized values.
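The overflow mechanism can be reproduced in isolation, without torchao, using the scale from the output above (the value `1.7285e-06` is taken from the reproducer's output):

```python
import torch

# A tiny fp16 scale whose reciprocal exceeds the fp16 maximum (65504)
x = torch.tensor([0.0, 0.1], dtype=torch.float16)
scale = torch.tensor(1.7285e-06, dtype=torch.float16)

inv = scale.reciprocal()  # 1 / 1.7285e-6 ~ 5.8e5 > 65504 -> inf in fp16
q = x * inv               # 0 * inf -> nan; 0.1 * inf -> inf
print(inv)
print(q)
```

Here `inv` comes out as `inf`, so the zero elements of `x` become `nan` before any clamping to the float8 range happens.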
This is all, of course, simply about the range and precision of the data types involved, but I was wondering: is this a known issue? Would it make sense to force the `input * scale.reciprocal()` result to 0 wherever the corresponding input elements are 0?
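A minimal sketch of that suggestion, as a standalone helper (the function name `quantize_keep_zeros` is hypothetical, not part of torchao; this is just one way to mask the problematic elements):

```python
import torch

def quantize_keep_zeros(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize by scale.reciprocal(), but force exact-zero inputs to map to 0.

    Hypothetical workaround sketch: avoids 0 * inf -> nan when the
    reciprocal of a tiny scale overflows in the working dtype.
    """
    q = x * scale.reciprocal()
    # Wherever the input was exactly 0, override whatever q holds with 0.
    return torch.where(x == 0, torch.zeros_like(q), q)
```

With the reproducer's values, the zero elements stay 0 instead of turning into `NaN`; the nonzero elements are unaffected (they still overflow and would be clamped by the float8 conversion as before).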