Description
While working on improving the performance of `convert_image_dtype` in #6795, I found several cases where `convert_image_dtype` silently fails for the low precision floating point dtypes `torch.float16` and `torch.bfloat16`:
```python
import torch
from torchvision.transforms import functional as F

# torch.{float16, bfloat16} to any integer dtype
image = torch.tensor(1.0, dtype=torch.float16)
print(image, F.convert_image_dtype(image, torch.uint8), F.convert_image_dtype(image, torch.int8))

# torch.{int32, int64} to torch.float16
image = torch.tensor(2**31 - 1, dtype=torch.int32)
print(image, F.convert_image_dtype(image, torch.float16))
```

```
tensor(1., dtype=torch.float16) tensor(0, dtype=torch.uint8) tensor(-128, dtype=torch.int8)
tensor(2147483647, dtype=torch.int32) tensor(nan, dtype=torch.float16)
```
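Until this is fixed, a possible user-side workaround is to upcast low precision images to `torch.float32` before the conversion. A minimal sketch (the scaling expression mirrors the fixed-`eps` formula discussed in this issue; doing the upcast manually is my suggestion, not existing torchvision behavior):

```python
import torch

# A (b)float16 image at the top of the valid value range.
image = torch.tensor(1.0, dtype=torch.float16)

# float16 intermediate: 1.0 * (256 - 1e-3) rounds up to 256.0 in float16,
# which then wraps around when cast to uint8.
overflowed = image.mul(255 + 1.0 - 1e-3)

# float32 intermediate stays strictly below 256 and truncates to 255.
converted = image.to(torch.float32).mul(255 + 1.0 - 1e-3).to(torch.uint8)
print(converted)  # tensor(255, dtype=torch.uint8)
```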
- Converting a valid (b)float16 image in the value range `[0.0, 1.0]` to any integer dtype overflows the computation. This stems from the fact that `eps` is fixed:

  vision/torchvision/transforms/functional_tensor.py, lines 90 to 93 in 7a62a54

  This value is simply too large for (b)float16:

  ```python
  >>> image = torch.tensor(1.0, dtype=torch.float16)
  >>> image.mul(255 + 1.0 - 1e-3)  # float16 -> uint8
  tensor(256., dtype=torch.float16)
  >>> image.to(torch.float32).mul(255 + 1.0 - 1e-3)  # float32 -> uint8
  tensor(255.9990)
  >>> image.mul(255 + 1.0 - 7e-2)  # float16 -> uint8 with adjusted eps
  tensor(255.8750, dtype=torch.float16)
  ```
  The whole point of `eps` is to be as small as possible so that the values are distributed evenly. See Add convert_image_dtype to functionals #2078 (comment) for details. We could simply make `eps` dependent on the input dtype in a similar function.

- Converting an int{32, 64} image to `torch.float16` should not be possible, since `float16` cannot hold the maximum values:

  ```python
  >>> torch.finfo(torch.float16).max
  65504.0
  >>> torch.iinfo(torch.int16).max  # ok
  32767
  >>> torch.iinfo(torch.int32).max  # not ok
  2147483647
  >>> torch.iinfo(torch.int64).max  # not ok
  9223372036854775807
  >>> torch.finfo(torch.bfloat16).max  # bfloat16 does not have this issue
  3.3895313892515355e+38
  ```
  We are already raising an error for unsafe float to int conversions:

  vision/torchvision/transforms/functional_tensor.py, lines 78 to 83 in 7a62a54

  so we could simply do the same here.
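For the first point, a dtype-dependent `eps` could be derived from the target float dtype's machine epsilon. A minimal sketch, where the helper name `_scale_eps` and the scaling rule are my assumptions and not existing torchvision code:

```python
import torch

def _scale_eps(float_dtype: torch.dtype, max_int: int) -> float:
    # Hypothetical helper: pick an eps proportional to the dtype's machine
    # epsilon so that (max_int + 1 - eps) rounds to a value strictly below
    # max_int + 1 in that dtype. Falls back to the current fixed 1e-3 for
    # float32/float64, where the fixed value is already safe.
    return max(1e-3, (max_int + 1) * torch.finfo(float_dtype).eps)

def float_to_uint8(image: torch.Tensor) -> torch.Tensor:
    # Sketch of the float -> uint8 path with the dtype-dependent eps.
    max_int = 255
    eps = _scale_eps(image.dtype, max_int)
    return image.mul(max_int + 1.0 - eps).to(torch.uint8)

print(float_to_uint8(torch.tensor(1.0, dtype=torch.float16)))   # no overflow
print(float_to_uint8(torch.tensor(1.0, dtype=torch.bfloat16)))  # no overflow
```

Note that a coarser eps trades a slightly uneven value distribution at the top of the range for correctness, which is unavoidable given the low precision of these dtypes.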
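For the second point, the new check could mirror the existing float to int guard. A minimal sketch, where the function name and error message are assumptions:

```python
import torch

def _check_int_to_float_cast(image_dtype: torch.dtype, dtype: torch.dtype) -> None:
    # Hypothetical guard mirroring the existing float -> int check: refuse
    # the cast when the integer dtype's maximum value exceeds what the
    # target floating point dtype can represent.
    if torch.iinfo(image_dtype).max > torch.finfo(dtype).max:
        raise RuntimeError(f"The cast from {image_dtype} to {dtype} cannot be performed safely.")

_check_int_to_float_cast(torch.int16, torch.float16)   # ok
_check_int_to_float_cast(torch.int64, torch.bfloat16)  # ok
try:
    _check_int_to_float_cast(torch.int32, torch.float16)
except RuntimeError as e:
    print(e)  # int32 max (2**31 - 1) exceeds float16 max (65504.0)
```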