Skip to content

Commit e0fffe5

Browse files
authored
Add rounding mode in FP theory (#83)
1 parent 860261f commit e0fffe5

File tree

2 files changed

+173
-1
lines changed

2 files changed

+173
-1
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: xdsl-smt "%s" | xdsl-smt -t=smt | filecheck "%s"
2+
// RUN: xdsl-smt "%s" -t=smt | z3 -in
3+
4+
"builtin.module"() ({
5+
6+
%rne_full = "smt.fp.round_nearest_ties_to_even"() : () -> !smt.fp.rounding_mode
7+
%rne = "smt.fp.rne"() : () -> !smt.fp.rounding_mode
8+
%rna_full = "smt.fp.round_nearest_ties_to_away"() : () -> !smt.fp.rounding_mode
9+
%rna = "smt.fp.rna"() : () -> !smt.fp.rounding_mode
10+
%rtp_full = "smt.fp.round_toward_positive"() : () -> !smt.fp.rounding_mode
11+
%rtp = "smt.fp.rtp"() : () -> !smt.fp.rounding_mode
12+
%rtn_full = "smt.fp.round_toward_negative"() : () -> !smt.fp.rounding_mode
13+
%rtn = "smt.fp.rtn"() : () -> !smt.fp.rounding_mode
14+
%rtz_full = "smt.fp.round_toward_zero"() : () -> !smt.fp.rounding_mode
15+
%rtz = "smt.fp.rtz"() : () -> !smt.fp.rounding_mode
16+
17+
18+
%eq_rne = "smt.eq"(%rne_full, %rne) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool
19+
%eq_rna = "smt.eq"(%rna_full, %rna) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool
20+
%eq_rtp = "smt.eq"(%rtp_full, %rtp) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool
21+
%eq_rtn = "smt.eq"(%rtn_full, %rtn) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool
22+
%eq_rtz = "smt.eq"(%rtz_full, %rtz) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool
23+
24+
25+
"smt.assert"(%eq_rne) : (!smt.bool) -> ()
26+
// CHECK: (assert (= roundNearestTiesToEven RNE))
27+
"smt.assert"(%eq_rna) : (!smt.bool) -> ()
28+
// CHECK: (assert (= roundNearestTiesToAway RNA))
29+
"smt.assert"(%eq_rtp) : (!smt.bool) -> ()
30+
// CHECK: (assert (= roundTowardPositive RTP))
31+
"smt.assert"(%eq_rtn) : (!smt.bool) -> ()
32+
// CHECK: (assert (= roundTowardNegative RTN))
33+
"smt.assert"(%eq_rtz) : (!smt.bool) -> ()
34+
// CHECK: (assert (= roundTowardZero RTZ))
35+
}) : () -> ()

xdsl_smt/dialects/smt_floatingpoint_dialect.py

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,132 @@
3939
)
4040

4141

42+
@irdl_attr_definition
43+
class RoundingModeType(ParametrizedAttribute, SMTLibSort, TypeAttribute):
44+
"""
45+
Defines Rounding Mode of FP operations, it includes following constants and their abbreviated version
46+
:funs ((roundNearestTiesToEven RoundingMode) (RNE RoundingMode)
47+
(roundNearestTiesToAway RoundingMode) (RNA RoundingMode)
48+
(roundTowardPositive RoundingMode) (RTP RoundingMode)
49+
(roundTowardNegative RoundingMode) (RTN RoundingMode)
50+
(roundTowardZero RoundingMode) (RTZ RoundingMode)
51+
)
52+
"""
53+
54+
name = "smt.fp.rounding_mode"
55+
56+
def __init__(self):
57+
super().__init__()
58+
59+
def print_sort_to_smtlib(self, stream: IO[str]) -> None:
60+
print(f"RoundingMode", file=stream, end="")
61+
62+
63+
class RunningModeConstantOp(IRDLOperation, Pure, SMTLibOp):
64+
"""
65+
This class is an abstract class for all RoundingMode constants
66+
:funs ((roundNearestTiesToEven RoundingMode) (RNE RoundingMode)
67+
(roundNearestTiesToAway RoundingMode) (RNA RoundingMode)
68+
(roundTowardPositive RoundingMode) (RTP RoundingMode)
69+
(roundTowardNegative RoundingMode) (RTN RoundingMode)
70+
(roundTowardZero RoundingMode) (RTZ RoundingMode)
71+
)
72+
"""
73+
74+
res: OpResult = result_def(RoundingModeType)
75+
76+
def __init__(self):
77+
super().__init__(result_types=[RoundingModeType()])
78+
79+
def print_expr_to_smtlib(self, stream: IO[str], ctx: SMTConversionCtx) -> None:
80+
print(f"{self.constant_name()}", file=stream, end="")
81+
82+
@abstractmethod
83+
def constant_name(self) -> str:
84+
"""RoundingMode name when printed in SMTLib."""
85+
...
86+
87+
88+
@irdl_op_definition
89+
class RoundNearestTiesToEvenOp(RunningModeConstantOp):
90+
name = "smt.fp.round_nearest_ties_to_even"
91+
92+
def constant_name(self) -> str:
93+
return "roundNearestTiesToEven"
94+
95+
96+
@irdl_op_definition
97+
class RNEOp(RunningModeConstantOp):
98+
name = "smt.fp.rne"
99+
100+
def constant_name(self) -> str:
101+
return "RNE"
102+
103+
104+
@irdl_op_definition
105+
class RoundNearestTiesToAwayOp(RunningModeConstantOp):
106+
name = "smt.fp.round_nearest_ties_to_away"
107+
108+
def constant_name(self) -> str:
109+
return "roundNearestTiesToAway"
110+
111+
112+
@irdl_op_definition
113+
class RNAOp(RunningModeConstantOp):
114+
name = "smt.fp.rna"
115+
116+
def constant_name(self) -> str:
117+
return "RNA"
118+
119+
120+
@irdl_op_definition
121+
class RoundTowardPositiveOp(RunningModeConstantOp):
122+
name = "smt.fp.round_toward_positive"
123+
124+
def constant_name(self) -> str:
125+
return "roundTowardPositive"
126+
127+
128+
@irdl_op_definition
129+
class RTPOp(RunningModeConstantOp):
130+
name = "smt.fp.rtp"
131+
132+
def constant_name(self) -> str:
133+
return "RTP"
134+
135+
136+
@irdl_op_definition
137+
class RoundTowardNegativeOp(RunningModeConstantOp):
138+
name = "smt.fp.round_toward_negative"
139+
140+
def constant_name(self) -> str:
141+
return "roundTowardNegative"
142+
143+
144+
@irdl_op_definition
145+
class RTNOp(RunningModeConstantOp):
146+
name = "smt.fp.rtn"
147+
148+
def constant_name(self) -> str:
149+
return "RTN"
150+
151+
152+
@irdl_op_definition
153+
class RoundTowardZeroOp(RunningModeConstantOp):
154+
name = "smt.fp.round_toward_zero"
155+
156+
def constant_name(self) -> str:
157+
return "roundTowardZero"
158+
159+
160+
@irdl_op_definition
161+
class RTZOp(RunningModeConstantOp):
162+
name = "smt.fp.rtz"
163+
164+
def constant_name(self) -> str:
165+
return "RTZ"
166+
167+
42168
@irdl_attr_definition
43169
class FloatingPointType(ParametrizedAttribute, SMTLibSort, TypeAttribute):
44170
"""
@@ -208,6 +334,17 @@ def constant_name(self) -> str:
208334
PositiveInfinityOp,
209335
NegativeInfinityOp,
210336
NaNOp,
337+
# Rounding Mode constants
338+
RoundNearestTiesToEvenOp,
339+
RNEOp,
340+
RoundNearestTiesToAwayOp,
341+
RNAOp,
342+
RoundTowardPositiveOp,
343+
RTPOp,
344+
RoundTowardNegativeOp,
345+
RTNOp,
346+
RoundTowardZeroOp,
347+
RTZOp,
211348
],
212-
[FloatingPointType],
349+
[FloatingPointType, RoundingModeType],
213350
)

0 commit comments

Comments
 (0)