Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions tests/filecheck/dialects/fp-theory/roundingmode.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: xdsl-smt "%s" | xdsl-smt -t=smt | filecheck "%s"
// RUN: xdsl-smt "%s" -t=smt | z3 -in

"builtin.module"() ({

%rne_full = "smt.fp.round_nearest_ties_to_even"() : () -> !smt.fp.rounding_mode
%rne = "smt.fp.rne"() : () -> !smt.fp.rounding_mode
%rna_full = "smt.fp.round_nearest_ties_to_away"() : () -> !smt.fp.rounding_mode
%rna = "smt.fp.rna"() : () -> !smt.fp.rounding_mode
%rtp_full = "smt.fp.round_toward_positive"() : () -> !smt.fp.rounding_mode
%rtp = "smt.fp.rtp"() : () -> !smt.fp.rounding_mode
%rtn_full = "smt.fp.round_toward_negative"() : () -> !smt.fp.rounding_mode
%rtn = "smt.fp.rtn"() : () -> !smt.fp.rounding_mode
%rtz_full = "smt.fp.round_toward_zero"() : () -> !smt.fp.rounding_mode
%rtz = "smt.fp.rtz"() : () -> !smt.fp.rounding_mode


%eq_rne = "smt.eq"(%rne_full, %rne) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool
%eq_rna = "smt.eq"(%rna_full, %rna) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool
%eq_rtp = "smt.eq"(%rtp_full, %rtp) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool
%eq_rtn = "smt.eq"(%rtn_full, %rtn) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool
%eq_rtz = "smt.eq"(%rtz_full, %rtz) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool


"smt.assert"(%eq_rne) : (!smt.bool) -> ()
// CHECK: (assert (= roundNearestTiesToEven RNE))
"smt.assert"(%eq_rna) : (!smt.bool) -> ()
// CHECK: (assert (= roundNearestTiesToAway RNA))
"smt.assert"(%eq_rtp) : (!smt.bool) -> ()
// CHECK: (assert (= roundTowardPositive RTP))
"smt.assert"(%eq_rtn) : (!smt.bool) -> ()
// CHECK: (assert (= roundTowardNegative RTN))
"smt.assert"(%eq_rtz) : (!smt.bool) -> ()
// CHECK: (assert (= roundTowardZero RTZ))
}) : () -> ()
139 changes: 138 additions & 1 deletion xdsl_smt/dialects/smt_floatingpoint_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,132 @@
)


@irdl_attr_definition
class RoundingModeType(ParametrizedAttribute, SMTLibSort, TypeAttribute):
"""
Defines Rounding Mode of FP operations, it includes following constants and their abbreviated version
:funs ((roundNearestTiesToEven RoundingMode) (RNE RoundingMode)
(roundNearestTiesToAway RoundingMode) (RNA RoundingMode)
(roundTowardPositive RoundingMode) (RTP RoundingMode)
(roundTowardNegative RoundingMode) (RTN RoundingMode)
(roundTowardZero RoundingMode) (RTZ RoundingMode)
)
"""

name = "smt.fp.rounding_mode"

def __init__(self):
super().__init__()

def print_sort_to_smtlib(self, stream: IO[str]) -> None:
print(f"RoundingMode", file=stream, end="")


class RunningModeConstantOp(IRDLOperation, Pure, SMTLibOp):
"""
This class is an abstract class for all RoundingMode constants
:funs ((roundNearestTiesToEven RoundingMode) (RNE RoundingMode)
(roundNearestTiesToAway RoundingMode) (RNA RoundingMode)
(roundTowardPositive RoundingMode) (RTP RoundingMode)
(roundTowardNegative RoundingMode) (RTN RoundingMode)
(roundTowardZero RoundingMode) (RTZ RoundingMode)
)
"""

res: OpResult = result_def(RoundingModeType)

def __init__(self):
super().__init__(result_types=[RoundingModeType()])

def print_expr_to_smtlib(self, stream: IO[str], ctx: SMTConversionCtx) -> None:
print(f"{self.constant_name()}", file=stream, end="")

@abstractmethod
def constant_name(self) -> str:
"""RoundingMode name when printed in SMTLib."""
...


@irdl_op_definition
class RoundNearestTiesToEvenOp(RunningModeConstantOp):
name = "smt.fp.round_nearest_ties_to_even"

def constant_name(self) -> str:
return "roundNearestTiesToEven"


@irdl_op_definition
class RNEOp(RunningModeConstantOp):
name = "smt.fp.rne"

def constant_name(self) -> str:
return "RNE"


@irdl_op_definition
class RoundNearestTiesToAwayOp(RunningModeConstantOp):
name = "smt.fp.round_nearest_ties_to_away"

def constant_name(self) -> str:
return "roundNearestTiesToAway"


@irdl_op_definition
class RNAOp(RunningModeConstantOp):
name = "smt.fp.rna"

def constant_name(self) -> str:
return "RNA"


@irdl_op_definition
class RoundTowardPositiveOp(RunningModeConstantOp):
name = "smt.fp.round_toward_positive"

def constant_name(self) -> str:
return "roundTowardPositive"


@irdl_op_definition
class RTPOp(RunningModeConstantOp):
name = "smt.fp.rtp"

def constant_name(self) -> str:
return "RTP"


@irdl_op_definition
class RoundTowardNegativeOp(RunningModeConstantOp):
name = "smt.fp.round_toward_negative"

def constant_name(self) -> str:
return "roundTowardNegative"


@irdl_op_definition
class RTNOp(RunningModeConstantOp):
name = "smt.fp.rtn"

def constant_name(self) -> str:
return "RTN"


@irdl_op_definition
class RoundTowardZeroOp(RunningModeConstantOp):
name = "smt.fp.round_toward_zero"

def constant_name(self) -> str:
return "roundTowardZero"


@irdl_op_definition
class RTZOp(RunningModeConstantOp):
name = "smt.fp.rtz"

def constant_name(self) -> str:
return "RTZ"


@irdl_attr_definition
class FloatingPointType(ParametrizedAttribute, SMTLibSort, TypeAttribute):
"""
Expand Down Expand Up @@ -208,6 +334,17 @@ def constant_name(self) -> str:
PositiveInfinityOp,
NegativeInfinityOp,
NaNOp,
# Rounding Mode constants
RoundNearestTiesToEvenOp,
RNEOp,
RoundNearestTiesToAwayOp,
RNAOp,
RoundTowardPositiveOp,
RTPOp,
RoundTowardNegativeOp,
RTNOp,
RoundTowardZeroOp,
RTZOp,
],
[FloatingPointType],
[FloatingPointType, RoundingModeType],
)