Skip to content

Commit 8a56bec

Browse files
authored
Add SMT FP theory (#80)
1 parent 3b627f5 commit 8a56bec

File tree

3 files changed

+248
-0
lines changed

3 files changed

+248
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: xdsl-smt "%s" | xdsl-smt -t=smt | filecheck "%s"
2+
// RUN: xdsl-smt "%s" -t=smt | z3 -in
3+
4+
"builtin.module"() ({
5+
6+
%three = smt.bv.constant #smt.bv<3> : !smt.bv<8>
7+
%four = smt.bv.constant #smt.bv<4> : !smt.bv<10>
8+
%zero = smt.bv.constant #smt.bv<0> : !smt.bv<1>
9+
%random_fp = "smt.fp.constant"(%zero, %three, %four) : (!smt.bv<1>, !smt.bv<8>, !smt.bv<10>) -> !smt.fp<8,11>
10+
%pzero = "smt.fp.pzero"() : () -> !smt.fp<8,11>
11+
%nzero = "smt.fp.nzero"() : () -> !smt.fp<8,11>
12+
%ninf = "smt.fp.ninf"() : () -> !smt.fp<8,11>
13+
%pinf = "smt.fp.pinf"() : () -> !smt.fp<8,11>
14+
%nan = "smt.fp.nan"() : () -> !smt.fp<8,11>
15+
16+
%eq_inf = "smt.eq"(%pinf, %ninf) : (!smt.fp<8,11>, !smt.fp<8,11>) -> !smt.bool
17+
"smt.assert"(%eq_inf) : (!smt.bool) -> ()
18+
// CHECK: (assert (= (_ +oo 8 11) (_ -oo 8 11)))
19+
20+
%eq_zero = "smt.eq"(%pzero, %nzero) : (!smt.fp<8,11>, !smt.fp<8,11>) -> !smt.bool
21+
"smt.assert"(%eq_zero) : (!smt.bool) -> ()
22+
// CHECK: (assert (= (_ +zero 8 11) (_ -zero 8 11)))
23+
24+
%eq = "smt.eq"(%random_fp, %nan) : (!smt.fp<8,11>, !smt.fp<8,11>) -> !smt.bool
25+
"smt.assert"(%eq) : (!smt.bool) -> ()
26+
// CHECK: (assert (= (fp (_ bv0 1) (_ bv3 8) (_ bv4 10)) (_ NaN 8 11)))
27+
28+
29+
30+
}) : () -> ()

xdsl_smt/cli/xdsl_smt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from xdsl.dialects.test import Test
1313
from xdsl.dialects.memref import MemRef
1414
from xdsl_smt.dialects.smt_array_dialect import SMTArray
15+
from xdsl_smt.dialects.smt_floatingpoint_dialect import SMTFloatingPointDialect
1516
from xdsl_smt.dialects.smt_int_dialect import SMTIntDialect
1617
from xdsl_smt.dialects.effects.effect import EffectDialect
1718
from xdsl_smt.dialects.effects.ub_effect import UBEffectDialect
@@ -79,6 +80,9 @@ def register_all_dialects(self):
7980
self.ctx.register_dialect(SMTArray.name, lambda: SMTArray)
8081
self.ctx.register_dialect(SMTTensorDialect.name, lambda: SMTTensorDialect)
8182
self.ctx.register_dialect(SMTUtilsDialect.name, lambda: SMTUtilsDialect)
83+
self.ctx.register_dialect(
84+
SMTFloatingPointDialect.name, lambda: SMTFloatingPointDialect
85+
)
8286
self.ctx.register_dialect(EffectDialect.name, lambda: EffectDialect)
8387
self.ctx.register_dialect(UBEffectDialect.name, lambda: UBEffectDialect)
8488
self.ctx.register_dialect(MemoryEffectDialect.name, lambda: MemoryEffectDialect)
@@ -99,6 +103,7 @@ def register_all_dialects(self):
99103
self.ctx.load_registered_dialect(SMTBitVectorDialect.name)
100104
self.ctx.load_registered_dialect(SMTUtilsDialect.name)
101105
self.ctx.load_registered_dialect(SMTArray.name)
106+
self.ctx.load_registered_dialect(SMTFloatingPointDialect.name)
102107

103108
def register_all_passes(self):
104109
super().register_all_passes()
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
from __future__ import annotations
2+
from abc import abstractmethod
3+
4+
from xdsl.printer import Printer
5+
6+
from xdsl.parser import AttrParser
7+
8+
from xdsl.dialects.builtin import (
9+
IntAttr,
10+
)
11+
from typing import IO, Sequence
12+
from xdsl_smt.dialects.smt_bitvector_dialect import BitVectorType
13+
14+
from xdsl.ir import (
15+
ParametrizedAttribute,
16+
TypeAttribute,
17+
OpResult,
18+
Attribute,
19+
Operation,
20+
Dialect,
21+
)
22+
23+
from xdsl.irdl import (
24+
operand_def,
25+
result_def,
26+
Operand,
27+
irdl_attr_definition,
28+
irdl_op_definition,
29+
IRDLOperation,
30+
)
31+
from xdsl.utils.exceptions import VerifyException
32+
from ..traits.effects import Pure
33+
34+
from ..traits.smt_printer import (
35+
SMTLibOp,
36+
SimpleSMTLibOp,
37+
SMTLibSort,
38+
SMTConversionCtx,
39+
)
40+
41+
42+
@irdl_attr_definition
43+
class FloatingPointType(ParametrizedAttribute, SMTLibSort, TypeAttribute):
44+
"""
45+
eb defines the number of bits in the exponent;
46+
sb defines the number of bits in the significand, *including * the hidden bit.
47+
"""
48+
49+
name = "smt.fp"
50+
eb: IntAttr
51+
sb: IntAttr
52+
53+
def __init__(self, eb: int | IntAttr, sb: int | IntAttr):
54+
if isinstance(eb, int):
55+
eb = IntAttr(eb)
56+
if isinstance(sb, int):
57+
sb = IntAttr(sb)
58+
super().__init__(eb, sb)
59+
60+
def verify(self) -> None:
61+
super().verify()
62+
if self.eb.data <= 0:
63+
raise VerifyException(
64+
"FloatingPointType exponent must be strictly greater "
65+
f"than zero, got {self.eb.data}"
66+
)
67+
if self.sb.data <= 0:
68+
raise VerifyException(
69+
"FloatingPointType significand must be strictly greater "
70+
f"than zero, got {self.sb.data}"
71+
)
72+
73+
@classmethod
74+
def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
75+
with parser.in_angle_brackets():
76+
eb = parser.parse_integer(allow_boolean=False, allow_negative=False)
77+
parser.parse_characters(",")
78+
sb = parser.parse_integer(allow_boolean=False, allow_negative=False)
79+
80+
return IntAttr(eb), IntAttr(sb)
81+
82+
def print_parameters(self, printer: Printer) -> None:
83+
printer.print_string(f"<{self.eb.data}, {self.sb.data}>")
84+
85+
def print_sort_to_smtlib(self, stream: IO[str]) -> None:
86+
print(f"(_ FloatingPoint {self.eb.data} {self.sb.data})", file=stream, end="")
87+
88+
89+
"""
90+
These correspond to the IEEE binary16, binary32, binary64 and binary128 formats.
91+
"""
92+
float16 = FloatingPointType(5, 11)
93+
float32 = FloatingPointType(8, 24)
94+
float64 = FloatingPointType(11, 53)
95+
float128 = FloatingPointType(15, 113)
96+
97+
98+
def getWidthFromBitVectorType(typ: Attribute):
99+
if isinstance(typ, BitVectorType):
100+
return typ.width.data
101+
raise ValueError("Expected a BitVector type")
102+
103+
104+
################################################################################
105+
# FP Value Constructors #
106+
################################################################################
107+
108+
109+
@irdl_op_definition
110+
class ConstantOp(IRDLOperation, Pure, SimpleSMTLibOp):
111+
"""
112+
FP literals as bit string triples, with the leading bit for the significand not represented (hidden bit)
113+
(fp (_ BitVec 1) (_ BitVec eb) (_ BitVec i) (_ FloatingPoint eb sb))
114+
where eb and sb are numerals greater than 1 and i = sb - 1.
115+
"""
116+
117+
name = "smt.fp.constant"
118+
lsb = operand_def(BitVectorType)
119+
eb = operand_def(BitVectorType)
120+
rsb = operand_def(BitVectorType)
121+
result = result_def(FloatingPointType)
122+
123+
def __init__(self, lsb: Operand, eb: Operand, rsb: Operand):
124+
eb_len = getWidthFromBitVectorType(eb.type)
125+
sb_len = getWidthFromBitVectorType(lsb.type) + getWidthFromBitVectorType(
126+
rsb.type
127+
)
128+
super().__init__(
129+
result_types=[FloatingPointType(eb_len, sb_len)], operands=[lsb, eb, rsb]
130+
)
131+
132+
def verify_(self):
133+
if not (1 == getWidthFromBitVectorType(self.lsb.type)):
134+
raise VerifyException("Expected leading significant bit with width 1")
135+
136+
def op_name(self) -> str:
137+
return "fp"
138+
139+
140+
class SpecialConstantOp(IRDLOperation, Pure, SMTLibOp):
141+
"""
142+
This class is an abstract class for -/+infinity, -/+zero and NaN
143+
"""
144+
145+
res: OpResult = result_def(FloatingPointType)
146+
147+
def __init__(self, eb: int | IntAttr, sb: int | IntAttr):
148+
super().__init__(result_types=[FloatingPointType(eb, sb)])
149+
150+
def print_expr_to_smtlib(self, stream: IO[str], ctx: SMTConversionCtx) -> None:
151+
assert isinstance(self, Operation)
152+
assert isinstance(self.res.type, FloatingPointType)
153+
print(f"(_ {self.constant_name()}", file=stream, end="")
154+
print(f" {self.res.type.eb.data} {self.res.type.sb.data})", file=stream, end="")
155+
156+
@abstractmethod
157+
def constant_name(self) -> str:
158+
"""Expression name when printed in SMTLib."""
159+
...
160+
161+
162+
@irdl_op_definition
163+
class PositiveInfinityOp(SpecialConstantOp):
164+
name = "smt.fp.pinf"
165+
166+
def constant_name(self) -> str:
167+
return "+oo"
168+
169+
170+
@irdl_op_definition
171+
class NegativeInfinityOp(SpecialConstantOp):
172+
name = "smt.fp.ninf"
173+
174+
def constant_name(self) -> str:
175+
return "-oo"
176+
177+
178+
@irdl_op_definition
179+
class PositiveZeroOp(SpecialConstantOp):
180+
name = "smt.fp.pzero"
181+
182+
def constant_name(self) -> str:
183+
return "+zero"
184+
185+
186+
@irdl_op_definition
187+
class NegativeZeroOp(SpecialConstantOp):
188+
name = "smt.fp.nzero"
189+
190+
def constant_name(self) -> str:
191+
return "-zero"
192+
193+
194+
@irdl_op_definition
195+
class NaNOp(SpecialConstantOp):
196+
name = "smt.fp.nan"
197+
198+
def constant_name(self) -> str:
199+
return "NaN"
200+
201+
202+
SMTFloatingPointDialect = Dialect(
203+
"smt.fp",
204+
[
205+
ConstantOp,
206+
PositiveZeroOp,
207+
NegativeZeroOp,
208+
PositiveInfinityOp,
209+
NegativeInfinityOp,
210+
NaNOp,
211+
],
212+
[FloatingPointType],
213+
)

0 commit comments

Comments
 (0)