convert_image_dtype overflows with low precision floating point dtypes #6799


Description

@pmeier

While working on improving the performance of convert_image_dtype in #6795, I found several cases where convert_image_dtype silently fails for the low-precision floating-point dtypes torch.float16 and torch.bfloat16:

import torch
from torchvision.transforms import functional as F

# torch.{float16, bfloat16} to any integer dtype
image = torch.tensor(1.0, dtype=torch.float16)
print(image, F.convert_image_dtype(image, torch.uint8), F.convert_image_dtype(image, torch.int8))

# torch.{int32, int64} to torch.float16
image = torch.tensor(2**31 - 1, dtype=torch.int32)
print(image, F.convert_image_dtype(image, torch.float16))
Output:

tensor(1., dtype=torch.float16) tensor(0, dtype=torch.uint8) tensor(-128, dtype=torch.int8)
tensor(2147483647, dtype=torch.int32) tensor(nan, dtype=torch.float16)
  1. Converting a valid (b)float16 image in the value range [0.0, 1.0] to any integer dtype overflows the computation. This stems from the fact that eps is fixed:

    eps = 1e-3
    max_val = float(_max_value(dtype))
    result = image.mul(max_val + 1.0 - eps)
    return result.to(dtype)

    This value is simply too large for (b)float16:

    >>> image = torch.tensor(1.0, dtype=torch.float16)
    >>> image.mul(255 + 1.0 - 1e-3)  # float16 -> uint8
    tensor(256., dtype=torch.float16)
    >>> image.to(torch.float32).mul(255 + 1.0 - 1e-3)  # float32 -> uint8
    tensor(255.9990)
    >>> image.mul(255 + 1.0 - 7e-2)  # float16 -> uint8 with adjusted eps
    tensor(255.8750, dtype=torch.float16)

    The whole point of eps is to be as small as possible to have an even value distribution. See Add convert_image_dtype to functionals #2078 (comment) for details.

    We could simply make eps dependent on the input dtype in a function similar to

    def _max_value(dtype: torch.dtype) -> int:
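
    A minimal sketch of such a dtype-dependent eps (the helper name _safe_eps and the exact formula are illustrative, not existing torchvision code) could derive it from the machine epsilon of the input dtype:

    import torch

    def _safe_eps(float_dtype: torch.dtype, max_val: float) -> float:
        # Hypothetical helper: pick an eps that is still representable right
        # below ``max_val + 1`` in the input dtype. ``max_val + 1`` is a power
        # of two for every integer dtype, so the spacing of representable
        # values just below it is ``(max_val + 1) / 2 * machine_eps``.
        machine_eps = torch.finfo(float_dtype).eps
        return max(1e-3, (max_val + 1.0) / 2.0 * machine_eps)

    >>> image = torch.tensor(1.0, dtype=torch.float16)
    >>> image.mul(255 + 1.0 - _safe_eps(torch.float16, 255.0))  # eps == 0.125
    tensor(255.8750, dtype=torch.float16)

    For float32 and float64 this falls back to the current 1e-3, while for a uint8 target float16 would get 0.125 and bfloat16 would get 1.0, i.e. just large enough that the scale factor stays below max_val + 1.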

  2. Converting an int{32, 64} image to float16 should not be possible, since float16 can't hold the maximum values:

    >>> torch.finfo(torch.float16).max
    65504.0
    >>> torch.iinfo(torch.int16).max  # ok
    32767
    >>> torch.iinfo(torch.int32).max  # not ok
    2147483647
    >>> torch.iinfo(torch.int64).max  # not ok
    9223372036854775807
    >>> torch.finfo(torch.bfloat16).max  # bfloat does not have this issue
    3.3895313892515355e+38

    We are already raising an error for unsafe float to int conversions

    # float to int
    if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
        image.dtype == torch.float64 and dtype == torch.int64
    ):
        msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
        raise RuntimeError(msg)

    so we could simply do the same here.
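
    A sketch of the extended check (the standalone helper _check_safe_cast is illustrative; in convert_image_dtype the condition lives inline):

    import torch

    def _check_safe_cast(input_dtype: torch.dtype, output_dtype: torch.dtype) -> None:
        # Existing float -> int checks plus the proposed int{32, 64} -> float16 case.
        unsafe = (
            (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64))
            or (input_dtype == torch.float64 and output_dtype == torch.int64)
            or (input_dtype in (torch.int32, torch.int64) and output_dtype == torch.float16)
        )
        if unsafe:
            raise RuntimeError(
                f"The cast from {input_dtype} to {output_dtype} cannot be performed safely."
            )

    >>> _check_safe_cast(torch.int32, torch.float16)
    RuntimeError: The cast from torch.int32 to torch.float16 cannot be performed safely.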

cc @vfdev-5 @datumbox
