-
Notifications
You must be signed in to change notification settings - Fork 750
Open
Labels
GPU (XLA on GPU) · bug (Something isn't working) · stat:awaiting openxla-eng (Awaiting response from openxla-eng)
Description
TL;DR: a/a != 1.0 in XLA on GPU
"""
Demonstrates JAX GPU float32 division precision bug.
IEEE 754 guarantees a / a = 1.0 for any finite, nonzero a, but JAX on GPU returns ~0.999999940395355
"""
import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# Test on GPU
print("\n=== GPU ===")
# The statements below must be indented under the `with` so that the division
# is executed while the GPU is the default device.
with jax.default_device(jax.devices('gpu')[0]):
    a = jnp.float32(3163.328613)
    result = a / a
    print(f"{float(a)} / {float(a)} = {float(result):.15f}")
    print(f"Error: {float(result) - 1.0:+.2e}")

# Test on CPU
print("\n=== CPU ===")
# Same computation with the CPU as the default device, for comparison.
with jax.default_device(jax.devices('cpu')[0]):
    a = jnp.float32(3163.328613)
    result = a / a
    print(f"{float(a)} / {float(a)} = {float(result):.15f}")
# `a` and `result` remain bound after the block, so the error line that
# follows still reports the CPU result.
print(f"Error: {float(result) - 1.0:+.2e}")
LLVM IR:
; Kernel wrapping the scalar float32 divide: reads one float through each of
; the two read-only pointers (%0, %1), divides them, and writes the quotient
; through the write-only pointer %2.
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
define ptx_kernel void @wrapped_divide(ptr noalias readonly align 16 captures(none) dereferenceable(4) %0, ptr noalias readonly align 16 captures(none) dereferenceable(4) %1, ptr noalias writeonly align 256 captures(none) dereferenceable(4) initializes((0, 4)) %2) local_unnamed_addr #0 {
; Cast the three incoming pointers to address space 1.
%4 = addrspacecast ptr %0 to ptr addrspace(1)
%5 = addrspacecast ptr %1 to ptr addrspace(1)
%6 = addrspacecast ptr %2 to ptr addrspace(1)
; Load the dividend (%7) and the divisor (%8).
%7 = load float, ptr addrspace(1) %4, align 16, !invariant.load !2
%8 = load float, ptr addrspace(1) %5, align 16, !invariant.load !2
; Plain fdiv: no fast-math flags are attached to this instruction.
%9 = fdiv float %7, %8
; Store the quotient to the output buffer.
store float %9, ptr addrspace(1) %6, align 256
ret void
}
PTX:
//
// Generated by LLVM NVPTX Back-End
//
// PTX for the wrapped_divide kernel: loads two f32 operands from global
// memory, divides them, and stores the quotient to the output pointer.
.version 8.8
.target sm_90a
.address_size 64
// .globl wrapped_divide
.visible .entry wrapped_divide(
.param .u64 .ptr .align 16 wrapped_divide_param_0,
.param .u64 .ptr .align 16 wrapped_divide_param_1,
.param .u64 .ptr .align 256 wrapped_divide_param_2
)
// Required thread-block dimensions: 1 x 1 x 1.
.reqntid 1, 1, 1
{
.reg .b32 %r<4>;
.reg .b64 %rd<7>;
// Read the three pointer parameters and convert each to a global address.
ld.param.b64 %rd1, [wrapped_divide_param_0];
cvta.to.global.u64 %rd2, %rd1;
ld.param.b64 %rd3, [wrapped_divide_param_1];
cvta.to.global.u64 %rd4, %rd3;
ld.param.b64 %rd5, [wrapped_divide_param_2];
cvta.to.global.u64 %rd6, %rd5;
// Load the dividend (%r1) and divisor (%r2) via non-coherent global loads.
ld.global.nc.b32 %r1, [%rd2];
ld.global.nc.b32 %r2, [%rd4];
// div.full.f32 is an approximate division (not fully IEEE 754 compliant,
// up to 2 ulp error), which is why a / a can differ from 1.0 here.
div.full.f32 %r3, %r1, %r2;
// Write the quotient to the output buffer.
st.global.b32 [%rd6], %r3;
ret;
}
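The root cause is that the `fdiv float` in the LLVM IR is lowered to `div.full.f32` in the PTX rather than to the IEEE 754-compliant `div.rn.f32`. From the PTX ISA documentation:
div.full.f32 implements a relatively fast, full-range approximation that scales operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not support rounding modifiers. The maximum ulp error is 2 across the full range of inputs.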
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
GPU (XLA on GPU) · bug (Something isn't working) · stat:awaiting openxla-eng (Awaiting response from openxla-eng)