|
| 1 | +/- |
| 2 | +Copyright (c) 2025 Robin Arnez. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Robin Arnez |
| 5 | +-/ |
| 6 | + |
| 7 | +import Batteries.Data.Nat.Basic |
| 8 | + |
| 9 | +/-! |
| 10 | +# Simple functions used in the axiomatic redefinition of floats (temporary). |
| 11 | +-/ |
| 12 | + |
| 13 | +/-- Returns whether `x` is a NaN bit pattern, that is: `Float.ofBits x = Float.nan` -/ |
| 14 | +def Float.isNaNBits (x : UInt64) : Bool := |
| 15 | + (x >>> 52) &&& 0x7ff = 0x7ff ∧ x &&& 0x000f_ffff_ffff_ffff ≠ 0 |
| 16 | + |
| 17 | +/-- |
| 18 | +Returns the sign bit of the given floating point number, i.e. whether |
| 19 | +the given `Float` is negative. NaN is considered positive in this function. |
| 20 | +-/ |
| 21 | +def Float.signBit (x : Float) : Bool := |
| 22 | + x.toBits >>> 63 ≠ 0 |
| 23 | + |
| 24 | +/-- |
| 25 | +Returns the exponent part of the given floating point number which is a value between |
| 26 | +`0` and `2047` (inclusive) describing the exponent of the floating point number. |
| 27 | +NaN and infinity have exponent part `2047`, `1` has exponent part `1023`, `2` has `1024`, etc. |
| 28 | +-/ |
| 29 | +def Float.exponentPart (x : Float) : UInt64 := |
| 30 | + (x.toBits >>> 52) &&& 0x7ff |
| 31 | + |
| 32 | +/-- |
| 33 | +Returns the mantissa of the given floating point number (without any implicit digits). |
| 34 | +-/ |
| 35 | +def Float.mantissa (x : Float) : UInt64 := |
| 36 | + x.toBits &&& 0x000f_ffff_ffff_ffff |
| 37 | + |
| 38 | +/-- |
| 39 | +Constructs a floating point number from the given parts (sign, exponent and mantissa). |
| 40 | +This function expects `exponentPart < 2047` and `mantissa < 2 ^ 52` in order to work correctly. |
| 41 | +-/ |
| 42 | +def Float.fromParts (exponentPart : UInt64) (mantissa : UInt64) : Float := |
| 43 | + Float.ofBits ((exponentPart <<< 52) ||| mantissa) |
| 44 | + |
| 45 | +/-- |
| 46 | +The floating point value "positive infinity", also used to represent numerical computations |
| 47 | +which produce finite values outside of the representable range of `Float`. |
| 48 | +-/ |
| 49 | +def Float.inf : Float := fromParts 2047 0 |
| 50 | + |
| 51 | +/-- |
| 52 | +The floating point value "not a number", used to represent erroneous numerical computations |
| 53 | +such as `0 / 0`. Using `nan` in any float operation will return `nan`, and all comparisons |
| 54 | +involving `nan` except `!=` return `false`, including in particular `nan == nan`. |
| 55 | +-/ |
| 56 | +def Float.nan : Float := fromParts 2047 0x0008_0000_0000_0000 |
| 57 | + |
| 58 | +/-- |
| 59 | +Returns a pair of values `(a, b)` such `x = a / b` (assuming `x` is finite). |
| 60 | +-/ |
| 61 | +def Float.toNumDenPair (x : Float) : Int × Nat := |
| 62 | + let signMul := bif x.signBit then (-1) else 1 |
| 63 | + let exp := x.exponentPart |
| 64 | + if exp = 0 then (x.mantissa.toNat * signMul, 1 <<< 1074) |
| 65 | + else |
| 66 | + let mant := x.mantissa.toNat ||| 0x0010_0000_0000_0000 |
| 67 | + (mant <<< (exp.toNat - 1075) * signMul, 1 <<< (1075 - exp.toNat)) |
| 68 | + |
| 69 | +/-- |
| 70 | +Divide two natural numbers, to produce a correctly rounded (nearest-ties-to-even) `Float` result. |
| 71 | +-/ |
| 72 | +def Nat.divFloat (x y : Nat) : Float := |
| 73 | + if y = 0 then if x = 0 then Float.nan else Float.inf else |
| 74 | + if x = 0 then Float.fromParts 0 0 else |
| 75 | + -- calculate `log2 = ⌊log 2 (x / y)⌋` |
| 76 | + let log2 : Int := x.log2 - y.log2 |
| 77 | + let log2 := if x <<< y.log2 < y <<< x.log2 then log2 - 1 else log2 |
| 78 | + -- if `x / y ≥ 2 ^ 1024`, return positive infinity |
| 79 | + if log2 ≥ 1024 then Float.inf else |
| 80 | + |
| 81 | + let exp := 53 - max log2 (-1022) |
| 82 | + -- calculate `mantissa = round (x / y * 2 ^ (exp - 1))` (rounding nearest-ties-to-even) |
| 83 | + let num := x <<< exp.toNat |
| 84 | + let den := y <<< (-exp).toNat |
| 85 | + let div := num / den |
| 86 | + let mantissa := |
| 87 | + if div &&& 3 = 1 ∧ div * den = num then div >>> 1 else (div + 1) >>> 1 |
| 88 | + |
| 89 | + if log2 < -1022 then |
| 90 | + -- subnormal |
| 91 | + if mantissa = 0x0010_0000_0000_0000 then -- overflow |
| 92 | + Float.fromParts 1 0 -- smallest normal float |
| 93 | + else |
| 94 | + Float.fromParts 0 mantissa.toUInt64 |
| 95 | + else |
| 96 | + -- normal |
| 97 | + if mantissa = 0x0020_0000_0000_0000 then -- overflow |
| 98 | + -- also works for infinity |
| 99 | + Float.fromParts (log2 + 1024).natAbs.toUInt64 0 |
| 100 | + else |
| 101 | + Float.fromParts (log2 + 1023).natAbs.toUInt64 |
| 102 | + (mantissa.toUInt64 &&& 0x000f_ffff_ffff_ffff) |
| 103 | + |
| 104 | +/-- |
| 105 | +Returns `sqrt (x * 2 ^ e)` as a floating point number. |
| 106 | +-/ |
| 107 | +def Float.sqrtHelper (x : Nat) (e : Int) : Float := |
| 108 | + -- log 2 (sqrt (x * 2 ^ e)) = log 2 (x * 2 ^ e) / 2 = ((log 2 x) + e) / 2 |
| 109 | + let log2 := (x.log2 + e) >>> 1 |
| 110 | + if log2 ≥ 1024 then Float.inf else |
| 111 | + |
| 112 | + -- we want `mantissa = round (sqrt (x * 2 ^ e) * 2 ^ exp)` |
| 113 | + -- round (sqrt (x * 2 ^ e) * 2 ^ exp) = |
| 114 | + -- round (sqrt (x * 2 ^ (e + 2 * exp))) |
| 115 | + -- we want variables `expInner` and `expOuter` with |
| 116 | + -- sqrt (x * 2 ^ (e + 2 * exp)) = sqrt (x * 2 ^ expInner) >>> expOuter |
| 117 | + -- TODO: prove that this implementation actually works |
| 118 | + let exp := 53 - max log2 (-1022) |
| 119 | + let e := e + 2 * exp |
| 120 | + let expInner := if e < 0 then (-e).toNat &&& 1 else e.toNat |
| 121 | + let expOuter := if e < 0 then (-e).toNat >>> 1 else 0 |
| 122 | + let val := x <<< expInner |
| 123 | + let sqrt := val.sqrt |
| 124 | + let result := sqrt >>> expOuter -- result = ⌊sqrt (x * 2 ^ e) * 2 ^ exp * 2⌋ |
| 125 | + let mantissa := |
| 126 | + if result &&& 3 = 1 ∧ result * result = val then result >>> 1 else (result + 1) >>> 1 |
| 127 | + |
| 128 | + if log2 < -1022 then |
| 129 | + -- subnormal |
| 130 | + if mantissa = 0x0010_0000_0000_0000 then -- overflow |
| 131 | + Float.fromParts 1 0 -- smallest normal float |
| 132 | + else |
| 133 | + Float.fromParts 0 mantissa.toUInt64 |
| 134 | + else |
| 135 | + -- normal |
| 136 | + if mantissa = 0x0020_0000_0000_0000 then -- overflow |
| 137 | + -- also works for infinity |
| 138 | + Float.fromParts (log2 + 1024).natAbs.toUInt64 0 |
| 139 | + else |
| 140 | + Float.fromParts (log2 + 1023).natAbs.toUInt64 |
| 141 | + (mantissa.toUInt64 &&& 0x000f_ffff_ffff_ffff) |
0 commit comments