[QST] About NaNs generated during FP16->FP8 quantization #1766

Open
@alexsamardzic

Description

The reproducer:

import torch
import torchao

from torchao.dtypes import Float8Layout, to_affine_quantized_floatx

x = torch.tensor([[0, 0, 0.1, 0.1]], dtype=torch.float16)

x_aqt = to_affine_quantized_floatx(
    x,
    target_dtype=torch.float8_e5m2,
    block_size=[1, x.shape[1]],
    _layout=Float8Layout(mm_config=None),
)
xq, scale, _ = x_aqt.tensor_impl.get_plain()

print("x =", x)
print("xq =", xq)
print("scale =", scale)

The output is:

x = tensor([[0.0000, 0.0000, 0.1000, 0.1000]], dtype=torch.float16)
xq = tensor([[   nan,    nan, 57344., 57344.]], dtype=torch.float8_e5m2)
scale = tensor([1.7285e-06], dtype=torch.float16)

Basically, the problem is that the quantization code maps the [0, 0.1] range to [0, 57344] (57344 being the maximum value of the torch.float8_e5m2 data type), so the scale becomes very small; its reciprocal, taken here, then overflows float16 to Inf, and 0 * Inf produces NaNs as quantized values.
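
For reference, the same failure mode can be reproduced in a few lines outside of torchao, using the values from the output above (a standalone sketch, not the torchao code path):

import torch

# fp8 e5m2 max is 57344, fp16 max is 65504
print(torch.finfo(torch.float8_e5m2).max)   # 57344.0
print(torch.finfo(torch.float16).max)       # 65504.0

# scale = amax / fp8_max = 0.1 / 57344 ~= 1.7e-6, stored in float16
scale = torch.tensor([0.1 / 57344], dtype=torch.float16)
print(scale.reciprocal())                   # inf: 1 / 1.7e-6 overflows float16

# quantization multiplies the input by the reciprocal of the scale,
# so zero input elements become 0 * inf = nan
print(torch.tensor([0.0], dtype=torch.float16) * scale.reciprocal())  # nan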

This is all, of course, simply about the range and precision of the data types involved, but I was wondering: is this a known issue? Would it make sense to force 0 as the input * scale.reciprocal() result here, wherever the corresponding input elements are 0?
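
For illustration, a sketch of what such a guard could look like (a hypothetical standalone helper, not the actual torchao code), assuming quantization is otherwise just input * scale.reciprocal():

import torch

def scale_and_guard_zeros(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch of the suggested fix: wherever an input element is 0,
    # force the scaled result to 0 so an overflowed scale.reciprocal() (Inf)
    # cannot turn into 0 * Inf = NaN.
    y = x * scale.reciprocal()
    return torch.where(x == 0, torch.zeros_like(y), y)

With the values from the reproducer, the zero elements would then map to 0 instead of NaN.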
