Skip to content

Commit b88f22d

Browse files
committed
feat: Float axioms for simplest functions, several lemmas
1 parent 80520e5 commit b88f22d

File tree

4 files changed

+303
-31
lines changed

4 files changed

+303
-31
lines changed

Batteries/Data/Float/Axioms.lean

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/-
2+
Copyright (c) 2025 Robin Arnez. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Robin Arnez
5+
-/
6+
7+
import Batteries.Data.Float.Basic
8+
import Lean.Elab.Tactic
9+
10+
/-!
11+
# Axiomatic redefinition of float functions
12+
13+
In this file, the most common floating point functions are
14+
axiomatically redefined to be used to later prove theorems about them.
15+
This is a temporary file, once there actually is a definition in
16+
core Lean, this will become unnecessary.
17+
-/
18+
19+
-- we don't want 10 different axioms for floats, we combine them here into one
20+
private structure Float.AxiomSet where
21+
ofBits_toBits (x : Float) : ofBits x.toBits = x
22+
toBits_ofBits (x : UInt64) : (ofBits x).toBits = if isNaNBits x then 0x7ff8_0000_0000_0000 else x
23+
isNaN_def (x : Float) : x.isNaN = isNaNBits x.toBits
24+
isInf_def (x : Float) : x.isInf = (x.exponentPart = 2047 ∧ x.mantissa = 0)
25+
isFinite_def (x : Float) : x.isFinite = (x.exponentPart < 2047)
26+
neg_def (x : Float) : x.neg = ofBits (x.toBits ^^^ 0x8000_0000_0000_0000)
27+
28+
/--
29+
Auxiliary axiom redefining the opaque `Float` functions.
30+
-/
31+
axiom Float.definitionAxiom : Float.AxiomSet
32+
33+
theorem Float.ofBits_toBits (x : Float) : ofBits x.toBits = x :=
34+
Float.definitionAxiom.ofBits_toBits x
35+
36+
theorem Float.toBits_ofBits (x : UInt64) :
37+
(ofBits x).toBits = if isNaNBits x then 0x7ff8_0000_0000_0000 else x :=
38+
Float.definitionAxiom.toBits_ofBits x
39+
40+
theorem Float.isNaN.eq_def (x : Float) :
41+
x.isNaN = Float.isNaNBits x.toBits :=
42+
Float.definitionAxiom.isNaN_def x
43+
44+
theorem Float.isNaN.eq_1 (x : Float) :
45+
x.isNaN = Float.isNaNBits x.toBits := by
46+
unfold Float.isNaN; rfl
47+
48+
theorem Float.isInf.eq_def (x : Float) :
49+
x.isInf = (x.exponentPart = 2047 ∧ x.mantissa = 0) :=
50+
Float.definitionAxiom.isInf_def x
51+
52+
theorem Float.isInf.eq_1 (x : Float) :
53+
x.isInf = (x.exponentPart = 2047 ∧ x.mantissa = 0) := by
54+
unfold Float.isInf; rfl
55+
56+
theorem Float.isFinite.eq_def (x : Float) :
57+
x.isFinite = (x.exponentPart < 2047) :=
58+
Float.definitionAxiom.isFinite_def x
59+
60+
theorem Float.isFinite.eq_1 (x : Float) :
61+
x.isFinite = (x.exponentPart < 2047) := by
62+
unfold Float.isFinite; rfl
63+
64+
theorem Float.neg.eq_def (x : Float) :
65+
x.neg = ofBits (x.toBits ^^^ 0x8000_0000_0000_0000) :=
66+
Float.definitionAxiom.neg_def x
67+
68+
theorem Float.neg.eq_1 (x : Float) :
69+
x.neg = ofBits (x.toBits ^^^ 0x8000_0000_0000_0000) := by
70+
unfold Float.neg; rfl

Batteries/Data/Float/Basic.lean

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/-
2+
Copyright (c) 2025 Robin Arnez. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Robin Arnez
5+
-/
6+
7+
import Batteries.Data.Nat.Basic
8+
9+
/-!
10+
# Simple functions used in the axiomatic redefinition of floats (temporary).
11+
-/
12+
13+
/-- Returns whether `x` is a NaN bit pattern, that is: `Float.ofBits x = Float.nan` -/
14+
def Float.isNaNBits (x : UInt64) : Bool :=
15+
(x >>> 52) &&& 0x7ff = 0x7ff ∧ x &&& 0x000f_ffff_ffff_ffff ≠ 0
16+
17+
/--
18+
Returns the sign bit of the given floating point number, i.e. whether
19+
the given `Float` is negative. NaN is considered positive in this function.
20+
-/
21+
def Float.signBit (x : Float) : Bool :=
22+
x.toBits >>> 630
23+
24+
/--
25+
Returns the exponent part of the given floating point number which is a value between
26+
`0` and `2047` (inclusive) describing the exponent of the floating point number.
27+
NaN and infinity have exponent part `2047`, `1` has exponent part `1023`, `2` has `1024`, etc.
28+
-/
29+
def Float.exponentPart (x : Float) : UInt64 :=
30+
(x.toBits >>> 52) &&& 0x7ff
31+
32+
/--
33+
Returns the mantissa of the given floating point number (without any implicit digits).
34+
-/
35+
def Float.mantissa (x : Float) : UInt64 :=
36+
x.toBits &&& 0x000f_ffff_ffff_ffff
37+
38+
/--
39+
Constructs a floating point number from the given parts (sign, exponent and mantissa).
40+
This function expects `exponentPart < 2047` and `mantissa < 2 ^ 52` in order to work correctly.
41+
-/
42+
def Float.fromParts (exponentPart : UInt64) (mantissa : UInt64) : Float :=
43+
Float.ofBits ((exponentPart <<< 52) ||| mantissa)
44+
45+
/--
46+
The floating point value "positive infinity", also used to represent numerical computations
47+
which produce finite values outside of the representable range of `Float`.
48+
-/
49+
def Float.inf : Float := fromParts 2047 0
50+
51+
/--
52+
The floating point value "not a number", used to represent erroneous numerical computations
53+
such as `0 / 0`. Using `nan` in any float operation will return `nan`, and all comparisons
54+
involving `nan` except `!=` return `false`, including in particular `nan == nan`.
55+
-/
56+
def Float.nan : Float := fromParts 2047 0x0008_0000_0000_0000
57+
58+
/--
59+
Returns a pair of values `(a, b)` such `x = a / b` (assuming `x` is finite).
60+
-/
61+
def Float.toNumDenPair (x : Float) : Int × Nat :=
62+
let signMul := bif x.signBit then (-1) else 1
63+
let exp := x.exponentPart
64+
if exp = 0 then (x.mantissa.toNat * signMul, 1 <<< 1074)
65+
else
66+
let mant := x.mantissa.toNat ||| 0x0010_0000_0000_0000
67+
(mant <<< (exp.toNat - 1075) * signMul, 1 <<< (1075 - exp.toNat))
68+
69+
/--
70+
Divide two natural numbers, to produce a correctly rounded (nearest-ties-to-even) `Float` result.
71+
-/
72+
def Nat.divFloat (x y : Nat) : Float :=
73+
if y = 0 then if x = 0 then Float.nan else Float.inf else
74+
if x = 0 then Float.fromParts 0 0 else
75+
-- calculate `log2 = ⌊log 2 (x / y)⌋`
76+
let log2 : Int := x.log2 - y.log2
77+
let log2 := if x <<< y.log2 < y <<< x.log2 then log2 - 1 else log2
78+
-- if `x / y ≥ 2 ^ 1024`, return positive infinity
79+
if log2 ≥ 1024 then Float.inf else
80+
81+
let exp := 53 - max log2 (-1022)
82+
-- calculate `mantissa = round (x / y * 2 ^ (exp - 1))` (rounding nearest-ties-to-even)
83+
let num := x <<< exp.toNat
84+
let den := y <<< (-exp).toNat
85+
let div := num / den
86+
let mantissa :=
87+
if div &&& 3 = 1 ∧ div * den = num then div >>> 1 else (div + 1) >>> 1
88+
89+
if log2 < -1022 then
90+
-- subnormal
91+
if mantissa = 0x0010_0000_0000_0000 then -- overflow
92+
Float.fromParts 1 0 -- smallest normal float
93+
else
94+
Float.fromParts 0 mantissa.toUInt64
95+
else
96+
-- normal
97+
if mantissa = 0x0020_0000_0000_0000 then -- overflow
98+
-- also works for infinity
99+
Float.fromParts (log2 + 1024).natAbs.toUInt64 0
100+
else
101+
Float.fromParts (log2 + 1023).natAbs.toUInt64
102+
(mantissa.toUInt64 &&& 0x000f_ffff_ffff_ffff)
103+
104+
/--
105+
Returns `sqrt (x * 2 ^ e)` as a floating point number.
106+
-/
107+
def Float.sqrtHelper (x : Nat) (e : Int) : Float :=
108+
-- log 2 (sqrt (x * 2 ^ e)) = log 2 (x * 2 ^ e) / 2 = ((log 2 x) + e) / 2
109+
let log2 := (x.log2 + e) >>> 1
110+
if log2 ≥ 1024 then Float.inf else
111+
112+
-- we want `mantissa = round (sqrt (x * 2 ^ e) * 2 ^ exp)`
113+
-- round (sqrt (x * 2 ^ e) * 2 ^ exp) =
114+
-- round (sqrt (x * 2 ^ (e + 2 * exp)))
115+
-- we want variables `expInner` and `expOuter` with
116+
-- sqrt (x * 2 ^ (e + 2 * exp)) = sqrt (x * 2 ^ expInner) >>> expOuter
117+
-- TODO: prove that this implementation actually works
118+
let exp := 53 - max log2 (-1022)
119+
let e := e + 2 * exp
120+
let expInner := if e < 0 then (-e).toNat &&& 1 else e.toNat
121+
let expOuter := if e < 0 then (-e).toNat >>> 1 else 0
122+
let val := x <<< expInner
123+
let sqrt := val.sqrt
124+
let result := sqrt >>> expOuter -- result = ⌊sqrt (x * 2 ^ e) * 2 ^ exp * 2⌋
125+
let mantissa :=
126+
if result &&& 3 = 1 ∧ result * result = val then result >>> 1 else (result + 1) >>> 1
127+
128+
if log2 < -1022 then
129+
-- subnormal
130+
if mantissa = 0x0010_0000_0000_0000 then -- overflow
131+
Float.fromParts 1 0 -- smallest normal float
132+
else
133+
Float.fromParts 0 mantissa.toUInt64
134+
else
135+
-- normal
136+
if mantissa = 0x0020_0000_0000_0000 then -- overflow
137+
-- also works for infinity
138+
Float.fromParts (log2 + 1024).natAbs.toUInt64 0
139+
else
140+
Float.fromParts (log2 + 1023).natAbs.toUInt64
141+
(mantissa.toUInt64 &&& 0x000f_ffff_ffff_ffff)

Batteries/Data/Float/Lemmas.lean

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/-
2+
Copyright (c) 2025 Robin Arnez. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Robin Arnez
5+
-/
6+
import Batteries.Data.Float.Axioms
7+
8+
theorem Float.toBits_inj {x y : Float} : x.toBits = y.toBits ↔ x = y := by
9+
constructor
10+
· intro h
11+
rw [← Float.ofBits_toBits x, ← Float.ofBits_toBits y, h]
12+
· rintro rfl
13+
rfl
14+
15+
example : (2047 <<< 52 ||| 2251799813685248) >>> 52 &&& 2047 = 2047 := by
16+
simp
17+
18+
theorem Float.toBits_ofBits_of_isNaNBits {x : UInt64} (h : isNaNBits x) :
19+
(ofBits x).toBits = 0x7ff8_0000_0000_0000 := by
20+
simp only [toBits_ofBits, h, reduceIte]
21+
22+
theorem Float.toBits_ofBits_of_not_isNaNBits {x : UInt64} (h : isNaNBits x = false) :
23+
(ofBits x).toBits = x := by
24+
simp only [toBits_ofBits, h, reduceIte, reduceCtorEq]
25+
26+
@[simp]
27+
theorem Float.toBits_nan : nan.toBits = 0x7ff8_0000_0000_0000 := by
28+
simp [nan, fromParts, toBits_ofBits_of_isNaNBits, isNaNBits, ← UInt64.toNat_inj]
29+
30+
@[simp]
31+
theorem Float.isNaN_nan : nan.isNaN := by
32+
rw [isNaN, Float.toBits_nan]; rfl
33+
34+
theorem Float.isNaN_iff_eq_nan (x : Float) : x.isNaN ↔ x = nan := by
35+
unfold isNaN
36+
constructor
37+
· intro h
38+
rw [← Float.ofBits_toBits x]
39+
rw [← Float.toBits_inj, Float.toBits_ofBits]
40+
simp only [h, reduceIte, Float.toBits_nan]
41+
· intro h
42+
rw [h, Float.toBits_nan]
43+
rfl
44+
45+
theorem Float.neg_def (x : Float) : -x = x.neg := rfl
46+
47+
@[simp]
48+
theorem Float.neg_nan : -nan = nan := by
49+
rw [Float.neg_def, neg, toBits_nan]
50+
rw [← Float.isNaN_iff_eq_nan, isNaN, toBits_ofBits]
51+
simp [isNaNBits, ← UInt64.toNat_inj]
52+
53+
protected theorem Float.neg_neg (x : Float) : -(-x) = x := by
54+
by_cases h : x = nan
55+
· rw [h, neg_nan, neg_nan]
56+
· simp only [Float.neg_def, Float.neg]
57+
rw [toBits_ofBits_of_not_isNaNBits, ← Float.toBits_inj, toBits_ofBits_of_not_isNaNBits]
58+
· simp only [← UInt64.toNat_inj, UInt64.toNat_xor, UInt64.reduceToNat]
59+
rw [Nat.xor_assoc, Nat.xor_self, Nat.xor_zero]
60+
repeat
61+
· rw [← Float.isNaN_iff_eq_nan, isNaN] at h
62+
simpa [isNaNBits, ← UInt64.toNat_inj, Nat.shiftRight_xor_distrib,
63+
Nat.and_xor_distrib_right] using h
64+
65+
theorem Float.neg_eq_nan_iff {x : Float} : -x = nan ↔ x = nan := by
66+
constructor
67+
· intro h
68+
rw [← Float.neg_neg x, h, neg_nan]
69+
· intro h
70+
rw [h, neg_nan]
71+
72+
theorem Float.exponentPart_neg (x : Float) : (-x).exponentPart = x.exponentPart := by
73+
by_cases h : x = nan
74+
· rw [h, neg_nan]
75+
· rw [neg_def, neg, exponentPart, exponentPart]
76+
rw [toBits_ofBits_of_not_isNaNBits]
77+
· simp [← UInt64.toNat_inj, Nat.shiftRight_xor_distrib, Nat.and_xor_distrib_right]
78+
· rw [← Float.isNaN_iff_eq_nan, isNaN] at h
79+
simpa [isNaNBits, ← UInt64.toNat_inj, Nat.shiftRight_xor_distrib,
80+
Nat.and_xor_distrib_right] using h
81+
82+
theorem Float.mantissa_neg (x : Float) : (-x).mantissa = x.mantissa := by
83+
by_cases h : x = nan
84+
· rw [h, neg_nan]
85+
· rw [neg_def, neg, mantissa, mantissa]
86+
rw [toBits_ofBits_of_not_isNaNBits]
87+
· simp [← UInt64.toNat_inj, Nat.and_xor_distrib_right]
88+
· rw [← Float.isNaN_iff_eq_nan, isNaN] at h
89+
simpa [isNaNBits, ← UInt64.toNat_inj, Nat.shiftRight_xor_distrib,
90+
Nat.and_xor_distrib_right] using h

Batteries/Lean/Float.lean

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,9 @@
44
Authors: Mario Carneiro
55
-/
66

7-
namespace Float
8-
9-
/--
10-
The floating point value "positive infinity", also used to represent numerical computations
11-
which produce finite values outside of the representable range of `Float`.
12-
-/
13-
def inf : Float := 1/0
7+
import Batteries.Data.Float.Basic
148

15-
/--
16-
The floating point value "not a number", used to represent erroneous numerical computations
17-
such as `0 / 0`. Using `nan` in any float operation will return `nan`, and all comparisons
18-
involving `nan` return `false`, including in particular `nan == nan`.
19-
-/
20-
def nan : Float := 0/0
9+
namespace Float
2110

2211
/-- Returns `v, exp` integers such that `f = v * 2^exp`.
2312
(`e` is not minimal, but `v.abs` will be at most `2^53 - 1`.)
@@ -74,24 +63,6 @@ def toStringFull (f : Float) : String :=
7463

7564
end Float
7665

77-
/--
78-
Divide two natural numbers, to produce a correctly rounded (nearest-ties-to-even) `Float` result.
79-
-/
80-
protected def Nat.divFloat (a b : Nat) : Float :=
81-
if b = 0 then
82-
if a = 0 then Float.nan else Float.inf
83-
else
84-
let ea := a.log2
85-
let eb := b.log2
86-
if eb + 1024 < ea then Float.inf else
87-
let eb' := if b <<< ea ≤ a <<< eb then eb else eb + 1
88-
let mantissa : UInt64 := (a <<< (eb' + 53) / b <<< ea).toUInt64
89-
let rounded := if mantissa &&& 3 == 1 && a <<< (eb' + 53) == mantissa.toNat * (b <<< ea) then
90-
mantissa >>> 1
91-
else
92-
(mantissa + 1) >>> 1
93-
rounded.toFloat.scaleB (ea - (eb' + 52))
94-
9566
/--
9667
Divide two integers, to produce a correctly rounded (nearest-ties-to-even) `Float` result.
9768
-/

0 commit comments

Comments
 (0)