diff --git a/tests/filecheck/dialects/fp-theory/roundingmode.mlir b/tests/filecheck/dialects/fp-theory/roundingmode.mlir new file mode 100644 index 00000000..17ec6a56 --- /dev/null +++ b/tests/filecheck/dialects/fp-theory/roundingmode.mlir @@ -0,0 +1,35 @@ +// RUN: xdsl-smt "%s" | xdsl-smt -t=smt | filecheck "%s" +// RUN: xdsl-smt "%s" -t=smt | z3 -in + +"builtin.module"() ({ + + %rne_full = "smt.fp.round_nearest_ties_to_even"() : () -> !smt.fp.rounding_mode + %rne = "smt.fp.rne"() : () -> !smt.fp.rounding_mode + %rna_full = "smt.fp.round_nearest_ties_to_away"() : () -> !smt.fp.rounding_mode + %rna = "smt.fp.rna"() : () -> !smt.fp.rounding_mode + %rtp_full = "smt.fp.round_toward_positive"() : () -> !smt.fp.rounding_mode + %rtp = "smt.fp.rtp"() : () -> !smt.fp.rounding_mode + %rtn_full = "smt.fp.round_toward_negative"() : () -> !smt.fp.rounding_mode + %rtn = "smt.fp.rtn"() : () -> !smt.fp.rounding_mode + %rtz_full = "smt.fp.round_toward_zero"() : () -> !smt.fp.rounding_mode + %rtz = "smt.fp.rtz"() : () -> !smt.fp.rounding_mode + + + %eq_rne = "smt.eq"(%rne_full, %rne) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool + %eq_rna = "smt.eq"(%rna_full, %rna) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool + %eq_rtp = "smt.eq"(%rtp_full, %rtp) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool + %eq_rtn = "smt.eq"(%rtn_full, %rtn) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool + %eq_rtz = "smt.eq"(%rtz_full, %rtz) : (!smt.fp.rounding_mode, !smt.fp.rounding_mode) -> !smt.bool + + + "smt.assert"(%eq_rne) : (!smt.bool) -> () + // CHECK: (assert (= roundNearestTiesToEven RNE)) + "smt.assert"(%eq_rna) : (!smt.bool) -> () + // CHECK: (assert (= roundNearestTiesToAway RNA)) + "smt.assert"(%eq_rtp) : (!smt.bool) -> () + // CHECK: (assert (= roundTowardPositive RTP)) + "smt.assert"(%eq_rtn) : (!smt.bool) -> () + // CHECK: (assert (= roundTowardNegative RTN)) + "smt.assert"(%eq_rtz) : (!smt.bool) -> () + // CHECK: (assert (= roundTowardZero RTZ)) +}) : () -> () diff --git a/xdsl_smt/dialects/smt_floatingpoint_dialect.py b/xdsl_smt/dialects/smt_floatingpoint_dialect.py index d3a0c60d..acf2ae1f 100644 --- a/xdsl_smt/dialects/smt_floatingpoint_dialect.py +++ b/xdsl_smt/dialects/smt_floatingpoint_dialect.py @@ -39,6 +39,132 @@ ) +@irdl_attr_definition +class RoundingModeType(ParametrizedAttribute, SMTLibSort, TypeAttribute): + """ + Defines Rounding Mode of FP operations, it includes following constants and their abbreviated version + :funs ((roundNearestTiesToEven RoundingMode) (RNE RoundingMode) + (roundNearestTiesToAway RoundingMode) (RNA RoundingMode) + (roundTowardPositive RoundingMode) (RTP RoundingMode) + (roundTowardNegative RoundingMode) (RTN RoundingMode) + (roundTowardZero RoundingMode) (RTZ RoundingMode) + ) + """ + + name = "smt.fp.rounding_mode" + + def __init__(self): + super().__init__() + + def print_sort_to_smtlib(self, stream: IO[str]) -> None: + print(f"RoundingMode", file=stream, end="") + + +class RunningModeConstantOp(IRDLOperation, Pure, SMTLibOp): + """ + This class is an abstract class for all RoundingMode constants + :funs ((roundNearestTiesToEven RoundingMode) (RNE RoundingMode) + (roundNearestTiesToAway RoundingMode) (RNA RoundingMode) + (roundTowardPositive RoundingMode) (RTP RoundingMode) + (roundTowardNegative RoundingMode) (RTN RoundingMode) + (roundTowardZero RoundingMode) (RTZ RoundingMode) + ) + """ + + res: OpResult = result_def(RoundingModeType) + + def __init__(self): + super().__init__(result_types=[RoundingModeType()]) + + def print_expr_to_smtlib(self, stream: IO[str], ctx: SMTConversionCtx) -> None: + print(f"{self.constant_name()}", file=stream, end="") + + @abstractmethod + def constant_name(self) -> str: + """RoundingMode name when printed in SMTLib.""" + ... + + +@irdl_op_definition +class RoundNearestTiesToEvenOp(RunningModeConstantOp): + name = "smt.fp.round_nearest_ties_to_even" + + def constant_name(self) -> str: + return "roundNearestTiesToEven" + + +@irdl_op_definition +class RNEOp(RunningModeConstantOp): + name = "smt.fp.rne" + + def constant_name(self) -> str: + return "RNE" + + +@irdl_op_definition +class RoundNearestTiesToAwayOp(RunningModeConstantOp): + name = "smt.fp.round_nearest_ties_to_away" + + def constant_name(self) -> str: + return "roundNearestTiesToAway" + + +@irdl_op_definition +class RNAOp(RunningModeConstantOp): + name = "smt.fp.rna" + + def constant_name(self) -> str: + return "RNA" + + +@irdl_op_definition +class RoundTowardPositiveOp(RunningModeConstantOp): + name = "smt.fp.round_toward_positive" + + def constant_name(self) -> str: + return "roundTowardPositive" + + +@irdl_op_definition +class RTPOp(RunningModeConstantOp): + name = "smt.fp.rtp" + + def constant_name(self) -> str: + return "RTP" + + +@irdl_op_definition +class RoundTowardNegativeOp(RunningModeConstantOp): + name = "smt.fp.round_toward_negative" + + def constant_name(self) -> str: + return "roundTowardNegative" + + +@irdl_op_definition +class RTNOp(RunningModeConstantOp): + name = "smt.fp.rtn" + + def constant_name(self) -> str: + return "RTN" + + +@irdl_op_definition +class RoundTowardZeroOp(RunningModeConstantOp): + name = "smt.fp.round_toward_zero" + + def constant_name(self) -> str: + return "roundTowardZero" + + +@irdl_op_definition +class RTZOp(RunningModeConstantOp): + name = "smt.fp.rtz" + + def constant_name(self) -> str: + return "RTZ" + + @irdl_attr_definition class FloatingPointType(ParametrizedAttribute, SMTLibSort, TypeAttribute): """ @@ -208,6 +334,17 @@ def constant_name(self) -> str: PositiveInfinityOp, NegativeInfinityOp, NaNOp, + # Rounding Mode constants + RoundNearestTiesToEvenOp, + RNEOp, + RoundNearestTiesToAwayOp, + RNAOp, + RoundTowardPositiveOp, + RTPOp, + RoundTowardNegativeOp, + RTNOp, + RoundTowardZeroOp, + RTZOp, ], - [FloatingPointType], + [FloatingPointType, RoundingModeType], )