Skip to content

Commit c8c010a

Browse files
authored
Support FP in Bitwuzla (#33)
1 parent ee5de56 commit c8c010a

File tree

6 files changed

+337
-93
lines changed

6 files changed

+337
-93
lines changed

ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaContext.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ open class KBitwuzlaContext : AutoCloseable {
3030
private val bitwuzlaSorts = hashMapOf<BitwuzlaSort, KSort>()
3131
private val declSorts = hashMapOf<KDecl<*>, BitwuzlaSort>()
3232
private val bitwuzlaConstants = hashMapOf<BitwuzlaTerm, KDecl<*>>()
33-
private val bvValues = hashMapOf<BitwuzlaTerm, KExpr<*>>()
33+
private val bitwuzlaValues = hashMapOf<BitwuzlaTerm, KExpr<*>>()
3434

3535
operator fun get(expr: KExpr<*>): BitwuzlaTerm? = expressions[expr]
3636
operator fun get(sort: KSort): BitwuzlaSort? = sorts[sort]
@@ -67,7 +67,7 @@ open class KBitwuzlaContext : AutoCloseable {
6767
* expressions.
6868
* */
6969
fun saveInternalizedValue(expr: KExpr<*>, term: BitwuzlaTerm) {
70-
bvValues[term] = expr
70+
bitwuzlaValues[term] = expr
7171
}
7272

7373
fun findConvertedExpr(expr: BitwuzlaTerm): KExpr<*>? = bitwuzlaExpressions[expr]
@@ -78,7 +78,7 @@ open class KBitwuzlaContext : AutoCloseable {
7878
fun convertSort(sort: BitwuzlaSort, converter: (BitwuzlaSort) -> KSort): KSort =
7979
convert(sorts, bitwuzlaSorts, sort, converter)
8080

81-
fun convertValue(value: BitwuzlaTerm): KExpr<*>? = bvValues[value]
81+
fun convertValue(value: BitwuzlaTerm): KExpr<*>? = bitwuzlaValues[value]
8282

8383
// Constant is known only if it was previously internalized
8484
fun convertConstantIfKnown(term: BitwuzlaTerm): KDecl<*>? = bitwuzlaConstants[term]

ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaExprConverter.kt

+195-8
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ import org.ksmt.decl.KDecl
55
import org.ksmt.decl.KFuncDecl
66
import org.ksmt.expr.KBitVecValue
77
import org.ksmt.expr.KExpr
8+
import org.ksmt.expr.KFpRoundingMode
89
import org.ksmt.expr.transformer.KNonRecursiveTransformer
910
import org.ksmt.expr.transformer.KTransformerBase
11+
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaBitVector
1012
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaKind
1113
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
1214
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
@@ -16,8 +18,11 @@ import org.ksmt.sort.KArraySort
1618
import org.ksmt.sort.KBoolSort
1719
import org.ksmt.sort.KBv1Sort
1820
import org.ksmt.sort.KBvSort
21+
import org.ksmt.sort.KFpRoundingModeSort
22+
import org.ksmt.sort.KFpSort
1923
import org.ksmt.sort.KSort
2024

25+
@Suppress("LargeClass")
2126
open class KBitwuzlaExprConverter(
2227
private val ctx: KContext,
2328
val bitwuzlaCtx: KBitwuzlaContext
@@ -79,6 +84,14 @@ open class KBitwuzlaExprConverter(
7984
val size = Native.bitwuzlaSortBvGetSize(sort)
8085
mkBvSort(size.toUInt())
8186
}
87+
Native.bitwuzlaSortIsFp(sort) -> {
88+
val exponent = Native.bitwuzlaSortFpGetExpSize(sort)
89+
val significand = Native.bitwuzlaSortFpGetSigSize(sort)
90+
mkFpSort(exponent.toUInt(), significand.toUInt())
91+
}
92+
Native.bitwuzlaSortIsRm(sort) -> {
93+
mkFpRoundingModeSort()
94+
}
8295
else -> TODO("Given sort $sort is not supported yet")
8396
}
8497
}
@@ -204,7 +217,7 @@ open class KBitwuzlaExprConverter(
204217
BitwuzlaKind.BITWUZLA_KIND_FP_TO_FP_FROM_SBV,
205218
BitwuzlaKind.BITWUZLA_KIND_FP_TO_FP_FROM_UBV,
206219
BitwuzlaKind.BITWUZLA_KIND_FP_TO_SBV,
207-
BitwuzlaKind.BITWUZLA_KIND_FP_TO_UBV -> TODO("FP are not supported yet")
220+
BitwuzlaKind.BITWUZLA_KIND_FP_TO_UBV -> convertFpExpr(expr, kind)
208221

209222
// unsupported
210223
BitwuzlaKind.BITWUZLA_NUM_KINDS,
@@ -278,11 +291,24 @@ open class KBitwuzlaExprConverter(
278291
Native.bitwuzlaTermIsBv(expr) -> bitwuzlaCtx.convertValue(expr) ?: run {
279292
convertBvValue(expr)
280293
}
281-
Native.bitwuzlaTermIsFp(expr) -> TODO("FP are not supported yet")
294+
Native.bitwuzlaTermIsFp(expr) -> bitwuzlaCtx.convertValue(expr) ?: run {
295+
convertFpValue(expr)
296+
}
297+
Native.bitwuzlaTermIsRm(expr) -> bitwuzlaCtx.convertValue(expr) ?: run {
298+
convertRmValue(expr)
299+
}
282300
else -> TODO("unsupported value $expr")
283301
}
284302
}
285303

304+
private fun BitwuzlaBitVector.getBit(idx: Int): Boolean =
305+
Native.bitwuzlaBvBitsGetBit(this, idx) != 0
306+
307+
private fun BooleanArray.toReversedBinaryString() = String(CharArray(size) { charIdx ->
308+
val bitIdx = size - 1 - charIdx
309+
if (this[bitIdx]) '1' else '0'
310+
})
311+
286312
private fun KContext.convertBvValue(expr: BitwuzlaTerm): KBitVecValue<KBvSort> {
287313
val size = Native.bitwuzlaTermBvGetSize(expr)
288314

@@ -297,12 +323,8 @@ open class KBitwuzlaExprConverter(
297323
val numericValue = Native.bitwuzlaBvBitsToUInt64(nativeBits).toULong()
298324
numericValue.toString(radix = 2).padStart(size, '0')
299325
} else {
300-
val bitChars = CharArray(size) { charIdx ->
301-
val bitIdx = size - 1 - charIdx
302-
val bit = Native.bitwuzlaBvBitsGetBit(nativeBits, bitIdx) != 0
303-
if (bit) '1' else '0'
304-
}
305-
String(bitChars)
326+
val bits = BooleanArray(size) { nativeBits.getBit(it) }
327+
bits.toReversedBinaryString()
306328
}
307329

308330
mkBv(bits, size.toUInt())
@@ -316,6 +338,76 @@ open class KBitwuzlaExprConverter(
316338
return convertedValue
317339
}
318340

341+
private fun KContext.convertFpValue(expr: BitwuzlaTerm): KExpr<KFpSort> {
342+
val sort = Native.bitwuzlaTermGetSort(expr).convertSort() as KFpSort
343+
344+
val convertedValue = if (Native.bitwuzlaTermIsFpValue(expr)) {
345+
// convert Fp value from native representation
346+
val nativeBits = Native.bitwuzlaFpConstNodeGetBits(bitwuzlaCtx.bitwuzla, expr)
347+
val nativeBitsSize = Native.bitwuzlaBvBitsGetWidth(nativeBits)
348+
val size = (sort.exponentBits + sort.significandBits).toInt()
349+
check(size == nativeBitsSize) {
350+
"Fp size mismatch, expr size $size, native size $nativeBitsSize "
351+
}
352+
353+
when (size) {
354+
Double.SIZE_BITS -> {
355+
val numericValue = Native.bitwuzlaBvBitsToUInt64(nativeBits)
356+
mkFp(Double.fromBits(numericValue), sort)
357+
}
358+
Float.SIZE_BITS -> {
359+
val numericValue = Native.bitwuzlaBvBitsToUInt64(nativeBits)
360+
mkFp(Float.fromBits(numericValue.toInt()), sort)
361+
}
362+
else -> {
363+
val exponentSize = sort.exponentBits
364+
val significandSize = sort.significandBits
365+
366+
val signBit = nativeBits.getBit(size - 1)
367+
val exponentBits = BooleanArray(exponentSize.toInt()) {
368+
val lowestBit = size - 1 - exponentSize.toInt()
369+
nativeBits.getBit(lowestBit + it)
370+
}
371+
val significandBits = BooleanArray(significandSize.toInt() - 1) {
372+
nativeBits.getBit(it)
373+
}
374+
375+
mkFpFromBvExpr(
376+
sign = mkBv(signBit),
377+
exponent = mkBv(exponentBits.toReversedBinaryString(), exponentSize),
378+
significand = mkBv(significandBits.toReversedBinaryString(), significandSize - 1u)
379+
)
380+
}
381+
}
382+
383+
} else {
384+
val value = Native.bitwuzlaGetFpValue(bitwuzlaCtx.bitwuzla, expr)
385+
386+
@Suppress("UNCHECKED_CAST")
387+
mkFpFromBvExpr(
388+
sign = mkBv(value.sign, sizeBits = 1u) as KExpr<KBv1Sort>,
389+
exponent = mkBv(value.exponent, value.exponent.length.toUInt()),
390+
significand = mkBv(value.significand, value.significand.length.toUInt())
391+
)
392+
}
393+
394+
bitwuzlaCtx.saveInternalizedValue(convertedValue, expr)
395+
396+
return convertedValue
397+
}
398+
399+
private fun KContext.convertRmValue(expr: BitwuzlaTerm): KExpr<KFpRoundingModeSort> {
400+
val kind = when {
401+
Native.bitwuzlaTermIsRmValueRne(expr) -> KFpRoundingMode.RoundNearestTiesToEven
402+
Native.bitwuzlaTermIsRmValueRna(expr) -> KFpRoundingMode.RoundNearestTiesToAway
403+
Native.bitwuzlaTermIsRmValueRtp(expr) -> KFpRoundingMode.RoundTowardPositive
404+
Native.bitwuzlaTermIsRmValueRtn(expr) -> KFpRoundingMode.RoundTowardNegative
405+
Native.bitwuzlaTermIsRmValueRtz(expr) -> KFpRoundingMode.RoundTowardZero
406+
else -> error("Unexpected rounding mode")
407+
}
408+
return mkFpRoundingModeExpr(kind)
409+
}
410+
319411
open fun KContext.convertBoolExpr(expr: BitwuzlaTerm, kind: BitwuzlaKind): ExprConversionResult = when (kind) {
320412
BitwuzlaKind.BITWUZLA_KIND_BV_AND, BitwuzlaKind.BITWUZLA_KIND_AND -> expr.convertList(::mkAnd)
321413
BitwuzlaKind.BITWUZLA_KIND_BV_OR, BitwuzlaKind.BITWUZLA_KIND_OR -> expr.convertList(::mkOr)
@@ -424,6 +516,94 @@ open class KBitwuzlaExprConverter(
424516
else -> error("unexpected BV kind $kind")
425517
}
426518

519+
@Suppress("LongMethod", "ComplexMethod")
520+
open fun convertFpExpr(expr: BitwuzlaTerm, kind: BitwuzlaKind): ExprConversionResult = when (kind) {
521+
BitwuzlaKind.BITWUZLA_KIND_FP_ABS -> expr.convert(ctx::mkFpAbsExpr)
522+
BitwuzlaKind.BITWUZLA_KIND_FP_ADD -> expr.convert(ctx::mkFpAddExpr)
523+
BitwuzlaKind.BITWUZLA_KIND_FP_SUB -> expr.convert(ctx::mkFpSubExpr)
524+
BitwuzlaKind.BITWUZLA_KIND_FP_MUL -> expr.convert(ctx::mkFpMulExpr)
525+
BitwuzlaKind.BITWUZLA_KIND_FP_FMA -> expr.convert(ctx::mkFpFusedMulAddExpr)
526+
BitwuzlaKind.BITWUZLA_KIND_FP_DIV -> expr.convert(ctx::mkFpDivExpr)
527+
BitwuzlaKind.BITWUZLA_KIND_FP_REM -> expr.convert(ctx::mkFpRemExpr)
528+
BitwuzlaKind.BITWUZLA_KIND_FP_MAX -> expr.convert(ctx::mkFpMaxExpr)
529+
BitwuzlaKind.BITWUZLA_KIND_FP_MIN -> expr.convert(ctx::mkFpMinExpr)
530+
BitwuzlaKind.BITWUZLA_KIND_FP_NEG -> expr.convert(ctx::mkFpNegationExpr)
531+
BitwuzlaKind.BITWUZLA_KIND_FP_RTI -> expr.convert(ctx::mkFpRoundToIntegralExpr)
532+
BitwuzlaKind.BITWUZLA_KIND_FP_SQRT -> expr.convert(ctx::mkFpSqrtExpr)
533+
BitwuzlaKind.BITWUZLA_KIND_FP_IS_INF -> expr.convert(ctx::mkFpIsInfiniteExpr)
534+
BitwuzlaKind.BITWUZLA_KIND_FP_IS_NAN -> expr.convert(ctx::mkFpIsNaNExpr)
535+
BitwuzlaKind.BITWUZLA_KIND_FP_IS_NORMAL -> expr.convert(ctx::mkFpIsNormalExpr)
536+
BitwuzlaKind.BITWUZLA_KIND_FP_IS_SUBNORMAL -> expr.convert(ctx::mkFpIsSubnormalExpr)
537+
BitwuzlaKind.BITWUZLA_KIND_FP_IS_NEG -> expr.convert(ctx::mkFpIsNegativeExpr)
538+
BitwuzlaKind.BITWUZLA_KIND_FP_IS_POS -> expr.convert(ctx::mkFpIsPositiveExpr)
539+
BitwuzlaKind.BITWUZLA_KIND_FP_IS_ZERO -> expr.convert(ctx::mkFpIsZeroExpr)
540+
BitwuzlaKind.BITWUZLA_KIND_FP_EQ -> expr.convert(ctx::mkFpEqualExpr)
541+
BitwuzlaKind.BITWUZLA_KIND_FP_LEQ -> expr.convert(ctx::mkFpLessOrEqualExpr)
542+
BitwuzlaKind.BITWUZLA_KIND_FP_LT -> expr.convert(ctx::mkFpLessExpr)
543+
BitwuzlaKind.BITWUZLA_KIND_FP_GEQ -> expr.convert(ctx::mkFpGreaterOrEqualExpr)
544+
BitwuzlaKind.BITWUZLA_KIND_FP_GT -> expr.convert(ctx::mkFpGreaterExpr)
545+
BitwuzlaKind.BITWUZLA_KIND_FP_TO_SBV ->
546+
expr.convert { rm: KExpr<KFpRoundingModeSort>, value: KExpr<KFpSort> ->
547+
val bvSize = Native.bitwuzlaTermGetIndices(expr).single()
548+
ctx.mkFpToBvExpr(rm, value, bvSize, isSigned = true)
549+
}
550+
BitwuzlaKind.BITWUZLA_KIND_FP_TO_UBV ->
551+
expr.convert { rm: KExpr<KFpRoundingModeSort>, value: KExpr<KFpSort> ->
552+
val bvSize = Native.bitwuzlaTermGetIndices(expr).single()
553+
ctx.mkFpToBvExpr(rm, value, bvSize, isSigned = false)
554+
}
555+
BitwuzlaKind.BITWUZLA_KIND_FP_TO_FP_FROM_SBV ->
556+
expr.convert { rm: KExpr<KFpRoundingModeSort>, value: KExpr<KBvSort> ->
557+
val sort = Native.bitwuzlaTermGetSort(expr).convertSort() as KFpSort
558+
ctx.mkBvToFpExpr(sort, rm, value, signed = true)
559+
}
560+
BitwuzlaKind.BITWUZLA_KIND_FP_TO_FP_FROM_UBV ->
561+
expr.convert { rm: KExpr<KFpRoundingModeSort>, value: KExpr<KBvSort> ->
562+
val sort = Native.bitwuzlaTermGetSort(expr).convertSort() as KFpSort
563+
ctx.mkBvToFpExpr(sort, rm, value, signed = false)
564+
}
565+
BitwuzlaKind.BITWUZLA_KIND_FP_FP -> expr.convert(ctx::mkFpFromBvExpr)
566+
BitwuzlaKind.BITWUZLA_KIND_FP_TO_FP_FROM_BV -> {
567+
val indices = Native.bitwuzlaTermGetIndices(expr)
568+
check(indices.size == 2) { "unexpected fp-from-bv indices: $indices" }
569+
val (exponentSize, significandSize) = indices
570+
val bvValue = Native.bitwuzlaTermGetChildren(expr).single()
571+
val bvSize = Native.bitwuzlaTermBvGetSize(bvValue)
572+
check(bvSize == exponentSize + significandSize) {
573+
"unexpected bv size in fp-from-bv: bv-size $bvSize, exp $exponentSize, significand $significandSize"
574+
}
575+
576+
// rewrite expression as fp_fp
577+
val signBv = Native.bitwuzlaMkTerm1Indexed2(
578+
bitwuzlaCtx.bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_EXTRACT,
579+
bvValue,
580+
bvSize - 1, bvSize - 1
581+
)
582+
val exponentBv = Native.bitwuzlaMkTerm1Indexed2(
583+
bitwuzlaCtx.bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_EXTRACT,
584+
bvValue,
585+
bvSize - 2, bvSize - 1 - exponentSize
586+
)
587+
val significandBv = Native.bitwuzlaMkTerm1Indexed2(
588+
bitwuzlaCtx.bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_EXTRACT,
589+
bvValue,
590+
significandSize - 2, 0
591+
)
592+
val rewritedExpr = Native.bitwuzlaMkTerm3(
593+
bitwuzlaCtx.bitwuzla, BitwuzlaKind.BITWUZLA_KIND_FP_FP,
594+
signBv, exponentBv, significandBv
595+
)
596+
597+
convertNativeExpr(rewritedExpr)
598+
}
599+
BitwuzlaKind.BITWUZLA_KIND_FP_TO_FP_FROM_FP ->
600+
expr.convert { rm: KExpr<KFpRoundingModeSort>, value: KExpr<KFpSort> ->
601+
val sort = Native.bitwuzlaTermGetSort(expr).convertSort() as KFpSort
602+
ctx.mkFpToFpExpr(sort, rm, value)
603+
}
604+
else -> error("unexpected Fp kind $kind")
605+
}
606+
427607
private fun <T : KDecl<*>> generateDecl(term: BitwuzlaTerm, generator: (String) -> T): T {
428608
val name = Native.bitwuzlaTermGetSymbol(term)
429609
val declName = name ?: generateBitwuzlaSymbol(term)
@@ -751,6 +931,13 @@ open class KBitwuzlaExprConverter(
751931
return convert(args, op)
752932
}
753933

934+
inline fun <T : KSort, A0 : KSort, A1 : KSort, A2 : KSort, A3 : KSort> BitwuzlaTerm.convert(
935+
op: (KExpr<A0>, KExpr<A1>, KExpr<A2>, KExpr<A3>) -> KExpr<T>
936+
): ExprConversionResult {
937+
val args = Native.bitwuzlaTermGetChildren(this)
938+
return convert(args, op)
939+
}
940+
754941
inline fun <T : KSort, A : KSort> BitwuzlaTerm.convertList(
755942
op: (List<KExpr<A>>) -> KExpr<T>
756943
): ExprConversionResult {

0 commit comments

Comments
 (0)