Summary
kRsqrtF64Budget.gpu.regular = 0 (declared in xla/codegen/intrinsic/accuracy/accuracy_budget.h) is empirically wrong on Blackwell (NVIDIA RTX 5090, compute capability 12.0a, CUDA 12.9, driver 13.2). Over a 1,014-sample sweep of normal-range f64 inputs, 26.04% of rsqrt results are 1 ULP off the reference 1.0 / sqrt(x), computed as two correctly rounded IEEE operations. Only 73.96% match it bit for bit.
Interestingly, 1 / sqrt(x) written out as two separate ops matches the reference on 100% of the same sample. The two lowering paths therefore diverge on GPU, which exposes both a budget bug and a pair of forms the algebraic simplifier treats as equivalent but that return different bits depending on which one the user writes.
Repro
Minimal JAX reproducer (same pattern as xla#40844 but for GPU f64):
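0.7071067811865475 is the reference value 1.0 / math.sqrt(2.0); jax.lax.rsqrt returns the next f64 up, 0.7071067811865476. The same pattern holds for x ∈ {1.5, 3.0, 0.5, 7.0, ...}: any input where the SIMD Newton-Raphson refinement lands on the other side of the rounding boundary from the reference. Note that the two-step reference rounds twice, so it is not necessarily the round-to-nearest-even f64 of the exact 1/sqrt(x); for x = 2 the exact value 0.70710678118654752440... is nearer to 0.7071067811865476.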
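Which of the two neighboring f64 values is nearer to the exact 1/sqrt(2) can be checked without trusting any floating-point reference. The sketch below uses only the Python standard library and exact rational arithmetic; the helper names (f64_bits, f64_from_bits, true_rsqrt_is_below, nearer_neighbor) are illustrative and do not come from XLA or JAX.

```python
from fractions import Fraction
import math
import struct


def f64_bits(v: float) -> int:
    """Raw IEEE-754 bit pattern of a Python float as an unsigned 64-bit int."""
    return struct.unpack("<Q", struct.pack("<d", v))[0]


def f64_from_bits(b: int) -> float:
    """Inverse of f64_bits."""
    return struct.unpack("<d", struct.pack("<Q", b))[0]


def true_rsqrt_is_below(m: Fraction, x: Fraction) -> bool:
    """Exact test of 1/sqrt(x) < m for positive m and x (equivalent to m*m*x > 1)."""
    return m * m * x > 1


def nearer_neighbor(lo: float, x: float) -> float:
    """Of lo and the next f64 above it, return the one nearer to the exact 1/sqrt(x)."""
    hi = f64_from_bits(f64_bits(lo) + 1)  # valid for positive finite lo
    midpoint = (Fraction(lo) + Fraction(hi)) / 2
    return lo if true_rsqrt_is_below(midpoint, Fraction(x)) else hi


two_step = 1.0 / math.sqrt(2.0)          # the reference used in this report
winner = nearer_neighbor(two_step, 2.0)  # exact-arithmetic decision
print(two_step.hex(), winner.hex())
```

Under IEEE-754 sqrt and division, the two-step value is 0.7071067811865475 and the neighbor nearer to the exact 1/sqrt(2) is 0.7071067811865476, one ULP above it. A mismatch against the two-step reference therefore does not by itself show which side is correctly rounded.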
Methodology
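Swept 1,000 log-uniform f64 inputs over [1e-300, 1e300] plus 14 curated values (including 1.0, 2.0, π, e, 1 + 2^-52). Computed rsqrt(x) and 1/sqrt(x) on the RTX 5090, then compared each result against the reference 1.0 / math.sqrt(x) at the bit level.

| Expression | Bit-exact vs reference | 1 ULP off |
|---|---|---|
| rsqrt(x) f64 | 73.96% | 26.04% |
| 1 / sqrt(x) f64 | 100% | 0% |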
- verify_rsqrt_f64_blackwell_sweep.exs: the 1,014-sample sweep whose results are reported in the table above.
- verify_blackwell_zero_ulp_claims.exs: broader audit across every 0-ULP GPU claim in kSqrtF64Budget, kRsqrtF64Budget, and the StableHLO correctly-rounded op set (add, sub, mul, div). Only kRsqrtF64Budget.gpu.regular = 0 is contradicted on Blackwell; the others all hold bit-exact (see below).
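The bit-level comparison described in the methodology can be sketched on the host as follows. This is not the Elixir/EXLA script itself: the function names are hypothetical, only the five curated values named above are listed (the full set of 14 is not given here), and Python's x ** -0.5 merely stands in for the device result, which this sketch cannot produce.

```python
import math
import random
import struct
from collections import Counter


def f64_bits(v: float) -> int:
    """Raw IEEE-754 bit pattern of a Python float as an unsigned 64-bit int."""
    return struct.unpack("<Q", struct.pack("<d", v))[0]


def ulp_distance(a: float, b: float) -> int:
    """Number of representable f64 steps between two finite values of the same sign."""
    return abs(f64_bits(a) - f64_bits(b))


def log_uniform_samples(n, lo=1e-300, hi=1e300, seed=0):
    """n samples whose base-10 exponent is uniform over [log10(lo), log10(hi)]."""
    rng = random.Random(seed)
    lo_e, hi_e = math.log10(lo), math.log10(hi)
    return [10.0 ** rng.uniform(lo_e, hi_e) for _ in range(n)]


# Only the curated values named in the text; the remaining ones are unknown here.
CURATED = [1.0, 2.0, math.pi, math.e, 1.0 + 2.0 ** -52]


def reference_rsqrt(x: float) -> float:
    """The two-step reference used in the sweep."""
    return 1.0 / math.sqrt(x)


def ulp_histogram(candidate, inputs):
    """Histogram of ULP distances between candidate(x) and the reference."""
    return Counter(ulp_distance(candidate(x), reference_rsqrt(x)) for x in inputs)


inputs = log_uniform_samples(1000) + CURATED
hist = ulp_histogram(lambda x: x ** -0.5, inputs)  # stand-in candidate
print(sorted(hist.items()))
```

Swapping the stand-in lambda for a function that runs the op on the GPU and returns a host f64 would give the kind of histogram the sweep reports.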
Additional audit result
Run of verify_blackwell_zero_ulp_claims.exs on the same hardware, scoping every 0-ULP f64 GPU claim:
| Op | Claim | Observed | Verdict |
|---|---|---|---|
| f64 sqrt | ≤0 ULP | 0 ULP on 809/809 (100%) | ✓ holds |
| f64 rsqrt | ≤0 ULP | 1 ULP on 233/809 (28.80%) | ✗ FAILS |
| f32 sqrt | ≤1 ULP | 1 ULP on 136/807 (16.85%) | ✓ holds (not bit-exact) |
| f64 divide | correctly rounded | 0 ULP on 400/400 | ✓ holds |
| f64 add | correctly rounded | 0 ULP on 400/400 | ✓ holds |
| f64 multiply | correctly rounded | 0 ULP on 400/400 | ✓ holds |
| f64 subtract | correctly rounded | 0 ULP on 400/400 | ✓ holds |
So the rsqrt bug is isolated — every other correctly-rounded claim (including f64 sqrt, which is the forward half of 1/sqrt) is honest on Blackwell. This narrows the fix surface to rsqrt itself.
Note also that kSqrtF32Budget.gpu.regular = 1 is consistent with observed behavior; f32 sqrt on Blackwell is not bit-exact, so any attempt to tighten that GPU budget to 0 would regress.
Environment
GPU: NVIDIA GeForce RTX 5090
Compute capability: 12.0a (Blackwell)
Driver: 13.2.0
CUDA Runtime: 12.9.0
CUDA Toolkit: 12.9.0
cuDNN: 9.13.0
XLA: revision bundled in elixir-nx/xla v0.10.0 (early-2026 snapshot of openxla/xla main)
Why this matters
StableHLO defines stablehlo.rsqrt as implementation-defined precision, so a 1-ULP result on Blackwell is not itself a spec violation. The bug is that the budget says otherwise — kRsqrtF64Budget.gpu.regular = 0 is a positive assertion that the emitter is bit-exact, and intrinsic_accuracy_test_gpu relies on it. Either:
- The test target hasn't been run on Blackwell and the claim was tuned against older architectures (Volta/Ampere/Hopper) where __nv_rsqrt happens to be correctly rounded for a different set of inputs, or
- The test target has a narrower input range than this sweep and doesn't hit the ~26% of Blackwell inputs where the Newton-Raphson body lands on the wrong side.
Separately, the fact that 1/sqrt(x) stays bit-exact while rsqrt(x) does not indicates that the f64 divide + sqrt lowering on GPU is not being rewritten to rsqrt by the algebraic simplifier on this path, despite HandleDivide containing divide(A, sqrt(B)) → multiply(A, rsqrt(B)) with no element-type guard. Worth investigating whether that's intentional (different one-use matching? different HLO shape?) or a latent difference between CPU and GPU pipelines that happens to save precision here.
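To make the rewrite under discussion concrete, here is a toy model of the pattern divide(A, sqrt(B)) → multiply(A, rsqrt(B)) with an optional element-type guard. It is a sketch only: the classes and the rewrite_divide_by_sqrt function are hypothetical and do not reflect how XLA's algebraic simplifier or HandleDivide is actually implemented.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Param:
    name: str
    dtype: str  # e.g. "f32" or "f64"


@dataclass(frozen=True)
class Unary:
    op: str  # "sqrt" or "rsqrt"
    operand: "Expr"


@dataclass(frozen=True)
class Binary:
    op: str  # "divide" or "multiply"
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Param, Unary, Binary]


def dtype_of(e: Expr) -> str:
    """Element type of an expression; this toy model assumes uniform dtypes."""
    if isinstance(e, Param):
        return e.dtype
    if isinstance(e, Unary):
        return dtype_of(e.operand)
    return dtype_of(e.lhs)


def rewrite_divide_by_sqrt(e: Expr, guarded_dtypes=frozenset()) -> Expr:
    """Rewrite divide(A, sqrt(B)) to multiply(A, rsqrt(B)) unless the dtype is guarded."""
    if (
        isinstance(e, Binary)
        and e.op == "divide"
        and isinstance(e.rhs, Unary)
        and e.rhs.op == "sqrt"
        and dtype_of(e) not in guarded_dtypes
    ):
        return Binary("multiply", e.lhs, Unary("rsqrt", e.rhs.operand))
    return e


a, b = Param("a", "f64"), Param("b", "f64")
expr = Binary("divide", a, Unary("sqrt", b))
unguarded = rewrite_divide_by_sqrt(expr)
guarded = rewrite_divide_by_sqrt(expr, guarded_dtypes=frozenset({"f64"}))
```

Without a guard the f64 expression is routed through rsqrt; with "f64" in guarded_dtypes it is left as divide + sqrt, which is the behavior the measurements above suggest the GPU path exhibits.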
Proposed fix options
1. Relax the budget to match reality. Change kRsqrtF64Budget.gpu.regular from 0 to 1, consistent with the CPU budget and with what the spec allows. Run intrinsic_accuracy_test_gpu on a Blackwell CI node to confirm.
2. Fix the GPU emitter so f64 rsqrt uses a correctly-bounded path (e.g., lower to an __nv_sqrt + __nv_drcp_rn composition; __nv_frcp_rn is the f32 reciprocal), the same pattern as the CPU fix in "[XLA] Fix f64 rsqrt 1-ULP error in CPU intrinsic and algebraic simplifier" #40844. This gives the same ≤1 ULP guarantee the CPU path now has. Even with the Newton-Raphson refinement removed, the result is still 1 ULP worst-case, so the budget should be relaxed either way.
Option 1 alone is sufficient to unbreak the claim. Option 2 is a real precision improvement for callers who currently hit the Newton-Raphson path.
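The claim that a sqrt-then-reciprocal composition is 1 ULP worst-case can be checked on the host against an exactly rounded reference. The sketch below rests on assumptions: Python's math.sqrt and float division stand in for correctly rounded device operations, and rsqrt_correctly_rounded is a hypothetical helper built on exact rational arithmetic, not anything from XLA or libdevice.

```python
from fractions import Fraction
import math
import random
import struct


def f64_bits(v: float) -> int:
    return struct.unpack("<Q", struct.pack("<d", v))[0]


def f64_from_bits(b: int) -> float:
    return struct.unpack("<d", struct.pack("<Q", b))[0]


def _true_rsqrt_below(m: Fraction, x: Fraction) -> bool:
    """Exact test of 1/sqrt(x) < m for positive m and x."""
    return m * m * x > 1


def rsqrt_correctly_rounded(x: float) -> float:
    """Round-to-nearest f64 of the exact 1/sqrt(x), for positive normal x."""
    exact_x = Fraction(x)
    y = 1.0 / math.sqrt(x)  # starting guess, already close to the answer
    while True:
        up = f64_from_bits(f64_bits(y) + 1)
        down = f64_from_bits(f64_bits(y) - 1)
        if not _true_rsqrt_below((Fraction(y) + Fraction(up)) / 2, exact_x):
            y = up      # exact value lies above the upper midpoint
        elif _true_rsqrt_below((Fraction(down) + Fraction(y)) / 2, exact_x):
            y = down    # exact value lies below the lower midpoint
        else:
            return y


def two_step_ulp_error(x: float) -> int:
    """ULP distance between 1.0 / sqrt(x) and the correctly rounded rsqrt."""
    return abs(f64_bits(1.0 / math.sqrt(x)) - f64_bits(rsqrt_correctly_rounded(x)))


rng = random.Random(0)
samples = [10.0 ** rng.uniform(-300.0, 300.0) for _ in range(200)] + [2.0, 4.0]
worst = max(two_step_ulp_error(x) for x in samples)
print("worst two-step error in ULPs:", worst)
```

The composition rounds twice, so it can miss the correctly rounded result (x = 2 is one such input), but the standard error bound keeps it within one step. That is consistent with relaxing the budget to 1 under either option.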
Reproducer scripts (Elixir/EXLA drivers of the same XLA GPU path; the JAX snippet above reproduces single inputs): https://gist.github.com/blasphemetheus/b11c03bbc9361c1f062741a03bbe8af7
Related
divide(A, sqrt(B)) simplifier guard for f64)