Skip to content

XLA (JAX) GPU float32 division precision bug #37181

@ezhulenev

Description

@ezhulenev

TL;DR: on GPU, XLA can compute float32 `a / a` as a value other than 1.0 (e.g. ~0.999999940395355 for a = 3163.328613).

"""
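Demonstrates a JAX GPU float32 division precision bug.
IEEE 754 division is correctly rounded, so a / a == 1.0 for any finite,
nonzero a; for a = 3163.328613, JAX on GPU returns ~0.999999940395355
(1 - 2**-24, one ulp below 1.0).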
"""
import jax
import jax.numpy as jnp

# float32 input for which the reported GPU quotient a / a is ~0.999999940395355
# instead of 1.0.
TEST_VALUE = 3163.328613


def self_divide(device, value=TEST_VALUE):
    """Compute ``value / value`` in float32 on ``device``.

    Args:
        device: JAX device to run on, e.g. ``jax.devices("cpu")[0]``.
        value: Dividend and divisor; rounded to float32 before dividing.

    Returns:
        Tuple ``(a, quotient)`` of Python floats, where ``a`` is ``value``
        rounded to float32 and ``quotient`` is the float32 result of
        ``a / a``. Correctly rounded IEEE 754 division gives exactly 1.0
        for finite, nonzero ``a``.
    """
    with jax.default_device(device):
        a = jnp.float32(value)
        quotient = a / a
        # float() copies each scalar back to the host.
        return float(a), float(quotient)


def report(label, platform, value=TEST_VALUE):
    """Print the self-division result on the first ``platform`` device.

    Args:
        label: Section title printed in the ``=== label ===`` header.
        platform: JAX platform name passed to ``jax.devices`` ("gpu", "cpu").
        value: Input forwarded to ``self_divide``.

    If the platform has no devices, a skip notice is printed instead of
    raising, so the CPU comparison still runs on machines without a GPU.
    """
    print(f"\n=== {label} ===")
    try:
        device = jax.devices(platform)[0]
    except RuntimeError as err:
        # jax.devices() raises RuntimeError when the backend is unavailable.
        print(f"Skipped: {err}")
        return
    a, quotient = self_divide(device, value)
    print(f"{a} / {a} = {quotient:.15f}")
    print(f"Error: {quotient - 1.0:+.2e}")


print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

report("GPU", "gpu")
report("CPU", "cpu")

LLVM IR generated by XLA for the division kernel:

; wrapped_divide: kernel that loads one float from %0 and one from %1,
; divides them, and stores the quotient to %2. Each pointer is
; dereferenceable(4), i.e. addresses a single f32.
; %0 = dividend (read-only), %1 = divisor (read-only), %2 = output (write-only).
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
define ptx_kernel void @wrapped_divide(ptr noalias readonly align 16 captures(none) dereferenceable(4) %0, ptr noalias readonly align 16 captures(none) dereferenceable(4) %1, ptr noalias writeonly align 256 captures(none) dereferenceable(4) initializes((0, 4)) %2) local_unnamed_addr #0 {
  ; Cast generic pointers to addrspace(1) (the PTX for this kernel uses cvta.to.global).
  %4 = addrspacecast ptr %0 to ptr addrspace(1)
  %5 = addrspacecast ptr %1 to ptr addrspace(1)
  %6 = addrspacecast ptr %2 to ptr addrspace(1)
  %7 = load float, ptr addrspace(1) %4, align 16, !invariant.load !2
  %8 = load float, ptr addrspace(1) %5, align 16, !invariant.load !2
  ; This fdiv has no fast-math flags (no arcp/afn) and no !fpmath metadata,
  ; so per the LLVM LangRef it asks for a correctly rounded quotient.
  ; NOTE(review): the precision loss appears during NVPTX lowering, which
  ; emits div.full.f32 for this instruction; presumably the back end is
  ; configured for approximate f32 division (e.g. nvptx-prec-divf32) — confirm.
  %9 = fdiv float %7, %8
  store float %9, ptr addrspace(1) %6, align 256
  ret void
}

PTX emitted for that IR by the LLVM NVPTX back end:

//
// Generated by LLVM NVPTX Back-End
//

.version 8.8
.target sm_90a
.address_size 64

	// .globl	wrapped_divide

// wrapped_divide: loads one f32 from each of param_0 and param_1, divides
// them, and stores the quotient to the address in param_2.
// Each param holds a 64-bit pointer.
.visible .entry wrapped_divide(
	.param .u64 .ptr .align 16 wrapped_divide_param_0,
	.param .u64 .ptr .align 16 wrapped_divide_param_1,
	.param .u64 .ptr .align 256 wrapped_divide_param_2
)
// Required CTA shape is 1x1x1: the kernel runs as a single thread.
.reqntid 1, 1, 1
{
	.reg .b32 	%r<4>;
	.reg .b64 	%rd<7>;

	// Load each pointer param and convert it to a global-state address.
	ld.param.b64 	%rd1, [wrapped_divide_param_0];
	cvta.to.global.u64 	%rd2, %rd1;
	ld.param.b64 	%rd3, [wrapped_divide_param_1];
	cvta.to.global.u64 	%rd4, %rd3;
	ld.param.b64 	%rd5, [wrapped_divide_param_2];
	cvta.to.global.u64 	%rd6, %rd5;
	// Read-only (.nc) loads of the dividend and divisor.
	ld.global.nc.b32 	%r1, [%rd2];
	ld.global.nc.b32 	%r2, [%rd4];
	// div.full.f32 is an approximation (max 2 ulp error, not fully IEEE 754
	// compliant); this is why a / a can differ from 1.0.
	// NOTE(review): a correctly rounded quotient would need div.rn.f32 —
	// confirm against the PTX ISA.
	div.full.f32 	%r3, %r1, %r2;
	st.global.b32 	[%rd6], %r3;
	ret;

}

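The root cause is that the `fdiv` is lowered to the PTX instruction `div.full.f32`, which the PTX ISA documentation describes as follows:

div.full.f32 implements a relatively fast, full-range approximation that scales operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not support rounding modifiers. The maximum ulp error is 2 across the full range of inputs.

An IEEE 754-compliant, correctly rounded quotient requires `div.rn.f32` instead.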

Metadata

Labels

GPU (XLA on GPU) · bug (Something isn't working) · stat:awaiting openxla-eng (Awaiting response from openxla-eng)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions