diff --git a/src/FloatingPointExpr.jl b/src/FloatingPointExpr.jl new file mode 100644 index 0000000..113d011 --- /dev/null +++ b/src/FloatingPointExpr.jl @@ -0,0 +1,237 @@ +import Base: +, -, *, /, ^, div, inv, mod, abs, ==, !=, promote_rule, convert + +abstract type NumericExpr <: AbstractExpr end + +# Global dictionary to track variable names for FloatingPointExpr +GLOBAL_VARNAMES = Dict{Type, Vector{String}}() +WARN_DUPLICATE_NAMES = true + +# Mapping rounding mode +const ROUNDING_MODE_MAP = Dict( + :RNE => :round_nearest_ties_to_even, + :RNA => :round_nearest_ties_to_away, + :RTP => :round_toward_positive, + :RTN => :round_toward_negative, + :RTZ => :round_toward_zero +) + +""" + FloatingPointExpr + +Represents a floating-point expression with support for operations, precision, and rounding modes. + +### Arguments: +- `op`: Symbol representing the operation (e.g., `:add`, `:mul`, etc.). +- `children`: Array of child expressions (inputs to the operation). +- `value`: The numerical value of the expression (can be `nothing` or `missing`). +- `name`: A unique name for the expression. +- `__is_commutative`: Boolean indicating whether the operation is commutative. +- `eb`: Exponent bit width. +- `sb`: Significand bit width (includes the hidden bit). +- `rounding_mode`: Symbol specifying the rounding mode. +""" +mutable struct FloatingPointExpr <: NumericExpr + op :: Symbol + children :: Vector{AbstractExpr} + value :: Union{Float64, Nothing, Missing} + name :: String + __is_commutative :: Bool + eb :: Int + sb :: Int + rounding_mode :: Symbol + + # for convenience + FloatingPointExpr(op::Symbol, children::Vector{T}, + value::Union{Float64, Nothing, Missing}, + name::String, __is_commutative::Bool, + eb::Int, sb::Int, rounding_mode::Symbol) where T <: AbstractExpr = new(op, children, value, name, __is_commutative, eb, sb, rounding_mode) +end + +""" + FloatingPointExpr(name::String; eb=11, sb=53, rounding_mode=:RNE) + +Create a new `FloatingPointExpr` instance with the given name, exponent bit width (`eb`), significand bit width (`sb`), and rounding mode. + +### Arguments: +- `name::String`: A unique name for the expression. +- `eb::Int`: Exponent bit width (default is 11). +- `sb::Int`: Significand bit width (default is 53). +- `rounding_mode::Symbol`: Rounding mode used in the operation (default is `:RNE`). + +```julia +expr = FloatingPointExpr("my_expr", eb=8, sb=24, rounding_mode=:RTP) +``` +""" +function FloatingPointExpr(name::String; value::Union{Float64, Nothing, Missing}=nothing, eb::Int=11, sb::Int=53, rounding_mode::Symbol=:RNE) + if haskey(ROUNDING_MODE_MAP, rounding_mode) + rounding_mode = ROUNDING_MODE_MAP[rounding_mode] + elseif !(rounding_mode in values(ROUNDING_MODE_MAP)) + throw(ArgumentError("Invalid rounding mode. See docstring for valid modes.")) + end + + # Ensure global variable tracking + if !haskey(GLOBAL_VARNAMES, FloatingPointExpr) + GLOBAL_VARNAMES[FloatingPointExpr] = String[] + end + + if name in GLOBAL_VARNAMES[FloatingPointExpr] + WARN_DUPLICATE_NAMES && @warn("Duplicate variable name: $name") + else + push!(GLOBAL_VARNAMES[FloatingPointExpr], name) + end + + return FloatingPointExpr(:identity, AbstractExpr[], value, name, false, eb, sb, rounding_mode) +end + +# Special FloatingPoint constants +""" + fp_zero(eb::Int=11, sb::Int=53, sign::Bool=false) + +Create a FloatingPointExpr representing zero, with optional sign (sign=true for negative zero). +""" +function fp_zero(eb::Int=11, sb::Int=53, sign::Bool=false) + value = sign ? -0.0 : 0.0 + FloatingPointExpr(:zero, [], value, "zero", true, eb, sb, :round_nearest_ties_to_even) +end + +""" + fp_infinity(eb::Int=11, sb::Int=53, sign::Bool=false) + +Create a FloatingPointExpr representing infinity, with optional sign (sign=true for negative infinity). +""" +function fp_infinity(eb::Int=11, sb::Int=53, sign::Bool=false) + value = sign ? -Inf : Inf + FloatingPointExpr(:infinity, [], value, "infinity", true, eb, sb, :round_nearest_ties_to_even) +end + +""" + fp_nan(eb::Int=11, sb::Int=53) + +Create a FloatingPointExpr representing NaN (Not a Number). +""" +function fp_nan(eb::Int=11, sb::Int=53) + FloatingPointExpr(:nan, [], NaN, "NaN", true, eb, sb, :round_nearest_ties_to_even) +end + +# FloatingPoint literals +""" + fp_literal(sign::Bool, exponent::Int, significand::Int, eb::Int, sb::Int) + +Create a floating-point literal expression using the given components: +- `sign`: Boolean indicating the sign (false for positive, true for negative) +- `exponent`: The exponent part of the floating-point number +- `significand`: The significand (fractional) part of the floating-point number +- `eb`: The exponent bit width +- `sb`: The significand bit width + +Returns a `FloatingPointExpr` representing the floating-point literal. +""" +function fp_literal(sign::Bool, exponent::Int, significand::Int, eb::Int, sb::Int) + # Calculate the value using ldexp, considering sign, exponent, and significand + value = if sign + -ldexp(significand / (1 << (sb - 1)), exponent - (1 << (eb - 1)) + 1) + else + ldexp(significand / (1 << (sb - 1)), exponent - (1 << (eb - 1)) + 1) + end + # Construct and return a FloatingPointExpr using the correct constructor + return FloatingPointExpr(:literal, AbstractExpr[], value, "literal", false, eb, sb, :round_nearest_ties_to_even) +end + +function is_nan(fp::FloatingPointExpr) + isnan(fp.value) +end + +function is_infinite(fp::FloatingPointExpr) + isinf(fp.value) +end + +function is_zero(fp::FloatingPointExpr) + fp.value == 0.0 +end + +function is_positive(fp::FloatingPointExpr) + fp.value > 0.0 +end + +function is_negative(fp::FloatingPointExpr) + fp.value < 0.0 +end + +# Arithmetic operations +Base.:+(fp1::FloatingPointExpr, fp2::FloatingPointExpr) = begin + if !fp1.__is_commutative && fp2.__is_commutative + fp1, fp2 = fp2, fp1 + end + result = fp1.value + fp2.value + rounded_result = round_float(result, fp1.rounding_mode) + FloatingPointExpr(:add, [fp1, fp2], rounded_result, "add", true, fp1.eb, fp1.sb, fp1.rounding_mode) +end + +Base.:*(fp1::FloatingPointExpr, fp2::FloatingPointExpr) = begin + if !fp1.__is_commutative && fp2.__is_commutative + fp1, fp2 = fp2, fp1 + end + result = fp1.value * fp2.value + rounded_result = round_float(result, fp1.rounding_mode) + FloatingPointExpr(:mul, [fp1, fp2], rounded_result, "mul", true, fp1.eb, fp1.sb, fp1.rounding_mode) +end + +Base.:-(fp1::FloatingPointExpr, fp2::FloatingPointExpr) = begin + result = fp1.value - fp2.value + rounded_result = round_float(result, fp1.rounding_mode) + FloatingPointExpr(:sub, [fp1, fp2], rounded_result, "sub", false, fp1.eb, fp1.sb, fp1.rounding_mode) +end + +Base.:/(fp1::FloatingPointExpr, fp2::FloatingPointExpr) = begin + result = fp1.value / fp2.value + rounded_result = round_float(result, fp1.rounding_mode) + FloatingPointExpr(:div, [fp1, fp2], rounded_result, "div", false, fp1.eb, fp1.sb, fp1.rounding_mode) +end + +# Fused Multiply-Add +function fp_fma(fp1::FloatingPointExpr, fp2::FloatingPointExpr, fp3::FloatingPointExpr) + if !fp1.__is_commutative && fp2.__is_commutative + fp1, fp2 = fp2, fp1 + end + result = fp1.value * fp2.value + fp3.value + rounded_result = round_float(result, fp1.rounding_mode) + FloatingPointExpr(:fma, [fp1, fp2, fp3], rounded_result, "fma", false, fp1.eb, fp1.sb, fp1.rounding_mode) +end + +function round_float(value::Float64, mode::Symbol) + + if mode == :round_nearest_ties_to_even + return round(value) + elseif mode == :round_toward_positive + return ceil(value) + elseif mode == :round_toward_negative + return floor(value) + elseif mode == :round_toward_zero + return trunc(value) + else + throw(ArgumentError("Unsupported rounding mode: $mode")) + end +end + + +Base.convert(::Type{FloatingPointExpr}, x::IntExpr) = begin + val = isnothing(x.value) ? 0.0 : float(x.value) # Default to 0.0 if value is `Nothing` or `Missing`s + op = :identity + children = AbstractExpr[] + name = "convert_from_int_$(x.name)" + eb = 11 + sb = 53 + rounding_mode = :round_nearest_ties_to_even + return FloatingPointExpr(op, children, val, name, true, eb, sb, rounding_mode) +end + +Base.convert(::Type{FloatingPointExpr}, x::RealExpr) = begin + val = isnothing(x.value) ? 0.0 : x.value # Default to 0.0 if value is `Nothing` or `Missing` + op = :identity + children = AbstractExpr[] + name = "convert_from_real_$(x.name)" + eb = 11 + sb = 53 + rounding_mode = :round_nearest_ties_to_even # Default rounding mode + return FloatingPointExpr(op, children, val, name, true, eb, sb, rounding_mode) +end diff --git a/src/Satisfiability.jl b/src/Satisfiability.jl index dabb5f7..0c4dd0b 100644 --- a/src/Satisfiability.jl +++ b/src/Satisfiability.jl @@ -9,6 +9,7 @@ export AbstractExpr, RealExpr, AbstractBitVectorExpr, BitVectorExpr, + FloatingPointExpr, isequal, hash, # required by isequal (?) in, # specialize to use isequal instead of == @@ -99,6 +100,8 @@ include("IntExpr.jl") include("BitVectorExpr.jl") +include("FloatingPointExpr.jl") + include("uninterpreted_func.jl") # include @satvariable later because we need some functions from BitVector to declare that type diff --git a/test/floating_point_tests.jl b/test/floating_point_tests.jl new file mode 100644 index 0000000..c865ca4 --- /dev/null +++ b/test/floating_point_tests.jl @@ -0,0 +1,165 @@ +@testitem "Floating Point" begin + + using Satisfiability + using Satisfiability: round_float, fp_literal + + # Constructor Tests + @testset "Constructor" begin + # Test for FloatingPointExpr with specific eb and sb + fp64 = FloatingPointExpr("fp64", value=2.5, eb=11, sb=53, rounding_mode=:RNE) + @test fp64.value == 2.5 + @test fp64.name == "fp64" + @test fp64.eb == 11 + @test fp64.sb == 53 + + @test fp64.rounding_mode == :round_nearest_ties_to_even + fp32 = FloatingPointExpr("fp32", value=3.14, eb=8, sb=24, rounding_mode=:RNE) + @test fp32.value == 3.14 + @test fp32.name == "fp32" + @test fp32.eb == 8 + @test fp32.sb == 24 + + @test fp32.rounding_mode == :round_nearest_ties_to_even + fp16 = FloatingPointExpr("fp16", value=1.0, eb=5, sb=11, rounding_mode=:RNE) + @test fp16.value == 1.0 + @test fp16.name == "fp16" + @test fp16.eb == 5 + @test fp16.sb == 11 + @test fp16.rounding_mode == :round_nearest_ties_to_even + end + + # Floating-point type synonyms tests (Float16, Float32, Float64, Float128) + @testset "Floating-Point Type Synonyms" begin + fp64 = FloatingPointExpr("fp64", value=2.5, eb=11, sb=53, rounding_mode=:RNE) + @test fp64.value == 2.5 + @test fp64.eb == 11 + @test fp64.sb == 53 + # Float16 + fp16 = FloatingPointExpr("fp16", value=1.0, eb=5, sb=11, rounding_mode=:RNE) + @test fp16.value == 1.0 + @test fp16.eb == 5 + @test fp16.sb == 11 + # Float32 + fp32 = FloatingPointExpr("fp32", value=3.14, eb=8, sb=24, rounding_mode=:RNE) + @test fp32.value == 3.14 + @test fp32.eb == 8 + @test fp32.sb == 24 + # Float128 + fp128 = FloatingPointExpr("fp128", value=1.23456789, eb=15, sb=113, rounding_mode=:RNE) + @test fp128.value == 1.23456789 + @test fp128.eb == 15 + @test fp128.sb == 113 + end + + # Arithmetic Operations + @testset "Arithmetic Operations" begin + fp1 = FloatingPointExpr("fp1", value=2.5, eb=11, sb=53) + fp2 = FloatingPointExpr("fp2", value=1.5, eb=11, sb=53) + + # Addition + fp_add = fp1 + fp2 + @test fp_add.value == 4.0 + @test fp_add.name == "add" + + # Subtraction + fp_sub = fp1 - fp2 + @test fp_sub.value == 1.0 + @test fp_sub.name == "sub" + + # Multiplication + fp_mul = fp1 * fp2 + @test fp_mul.value == 4.0 + @test fp_mul.name == "mul" + + # Division + fp_div = fp1 / fp2 + @test fp_div.value ≈ 2.0 + @test fp_div.name == "div" + end + + @testset "Special Values" begin + fp_zero = FloatingPointExpr("zero", value=0.0, eb=11, sb=53) + @test fp_zero.value == 0.0 + @test fp_zero.name == "zero" + + fp_nan = FloatingPointExpr("NaN", value=NaN, eb=11, sb=53) + @test isnan(fp_nan.value) + @test fp_nan.name == "NaN" + + fp_inf = FloatingPointExpr("positive_infinity", value=Inf, eb=11, sb=53) + @test fp_inf.value == Inf + @test fp_inf.name == "positive_infinity" + + fp_ninf = FloatingPointExpr("negative_infinity", value=-Inf, eb=11, sb=53) + @test fp_ninf.value == -Inf + @test fp_ninf.name == "negative_infinity" + end + + @testset "Conversion Tests" begin + int_expr = IntExpr("int1") + real_expr = RealExpr("real1") + # IntExpr to FloatingPointExpr + fp_from_int = convert(FloatingPointExpr, int_expr) + @test fp_from_int.value == 0.0 + # RealExpr to FloatingPointExpr + fp_from_real = convert(FloatingPointExpr, real_expr) + @test fp_from_real.value == 0.0 + end + + @testset "Rounding Tests" begin + fp = FloatingPointExpr("fp_round", value=1.23456789, eb=11, sb=53, rounding_mode=:RTP) + rounded_value = round_float(fp.value, fp.rounding_mode) + + @test rounded_value == ceil(1.23456789) # Round toward positive + fp = FloatingPointExpr("fp_round", value=1.23456789, eb=11, sb=53, rounding_mode=:RTN) + rounded_value = round_float(fp.value, fp.rounding_mode) + + @test rounded_value == floor(1.23456789) # Round toward negative + fp = FloatingPointExpr("fp_round", value=1.23456789, eb=11, sb=53, rounding_mode=:RTZ) + rounded_value = round_float(fp.value, fp.rounding_mode) + @test rounded_value == trunc(1.23456789) # Round toward zero + end + + @testset "Conversion Tests" begin + int_expr = IntExpr("int1") + real_expr = RealExpr("real1") + fp_from_int = convert(FloatingPointExpr, int_expr) + @test fp_from_int.value == 0.0 + @test fp_from_int.name == "convert_from_int_int1" + + fp_from_real = convert(FloatingPointExpr, real_expr) + @test fp_from_real.value == 0.0 + @test fp_from_real.name == "convert_from_real_real1" + end + + @testset "Test fp_literal - Positive" begin + fp_expr = fp_literal(false, 10, 12345, 11, 53) + expected_value = ldexp(12345 / (1 << (53 - 1)), 10 - (1 << (11 - 1)) + 1) + @test fp_expr.value ≈ expected_value + @test fp_expr.eb == 11 + @test fp_expr.sb == 53 + end + + @testset "Test fp_literal - Negative" begin + fp_expr = fp_literal(true, 10, 12345, 11, 53) + expected_value = -ldexp(12345 / (1 << (53 - 1)), 10 - (1 << (11 - 1)) + 1) + @test fp_expr.value ≈ expected_value + @test fp_expr.eb == 11 + @test fp_expr.sb == 53 + end + + @testset "Test fp_literal - Zero" begin + fp_expr = fp_literal(false, 0, 0, 11, 53) + @test fp_expr.value == 0.0 + @test fp_expr.eb == 11 + @test fp_expr.sb == 53 + end + + @testset "Test fp_literal - Small" begin + fp_expr = fp_literal(false, -10, 123, 11, 53) + expected_value = ldexp(123 / (1 << (53 - 1)), -10 - (1 << (11 - 1)) + 1) + @test fp_expr.value ≈ expected_value + @test fp_expr.eb == 11 + @test fp_expr.sb == 53 + end +end