|
| 1 | +from xdsl_smt.dialects import smt_array_dialect as smt_array |
| 2 | + |
| 3 | +from xdsl_smt.dialects.smt_dialect import ( |
| 4 | + DeclareConstOp, |
| 5 | +) |
| 6 | +from xdsl_smt.dialects.smt_tensor_dialect import ( |
| 7 | + IndexType, |
| 8 | + SMTTensorType, |
| 9 | + TensorExtractOp, |
| 10 | +) |
| 11 | +from xdsl.dialects.builtin import ModuleOp |
| 12 | +from xdsl.ir import Attribute |
| 13 | +from xdsl.utils.hints import isa |
| 14 | +from xdsl.context import Context |
| 15 | +from xdsl.pattern_rewriter import ( |
| 16 | + GreedyRewritePatternApplier, |
| 17 | + PatternRewriteWalker, |
| 18 | + PatternRewriter, |
| 19 | + RewritePattern, |
| 20 | + op_type_rewrite_pattern, |
| 21 | +) |
| 22 | +from xdsl.passes import ModulePass |
| 23 | + |
| 24 | + |
| 25 | +def lower_tensor_type(typ: Attribute) -> Attribute: |
| 26 | + if isa(typ, SMTTensorType): |
| 27 | + result = typ.element_type |
| 28 | + index_type = IndexType |
| 29 | + for _ in typ.shape: |
| 30 | + result = smt_array.ArrayType(index_type, result) |
| 31 | + return result |
| 32 | + return typ |
| 33 | + |
| 34 | + |
| 35 | +class DeclareConstOpPattern(RewritePattern): |
| 36 | + @op_type_rewrite_pattern |
| 37 | + def match_and_rewrite(self, op: DeclareConstOp, rewriter: PatternRewriter): |
| 38 | + if isa(op.res.type, SMTTensorType): |
| 39 | + new_constant_op = DeclareConstOp(lower_tensor_type(op.res.type)) |
| 40 | + rewriter.replace_matched_op(new_constant_op) |
| 41 | + |
| 42 | + |
| 43 | +class TensorExtractOpPattern(RewritePattern): |
| 44 | + @op_type_rewrite_pattern |
| 45 | + def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter): |
| 46 | + source = op.tensor |
| 47 | + assert isinstance(source.type, smt_array.ArrayType) |
| 48 | + select_ops: list[smt_array.SelectOp] = [] |
| 49 | + for idx in op.indices: |
| 50 | + select_ops.append(smt_array.SelectOp(source, idx)) |
| 51 | + source = select_ops[-1].res |
| 52 | + rewriter.replace_matched_op(select_ops) |
| 53 | + |
| 54 | + |
| 55 | +class LowerSMTTensor(ModulePass): |
| 56 | + name = "lower-smt-tensor" |
| 57 | + |
| 58 | + def apply(self, ctx: Context, op: ModuleOp): |
| 59 | + walker = PatternRewriteWalker( |
| 60 | + GreedyRewritePatternApplier( |
| 61 | + [DeclareConstOpPattern(), TensorExtractOpPattern()] |
| 62 | + ) |
| 63 | + ) |
| 64 | + walker.rewrite_module(op) |
0 commit comments