|
1 | 1 | from dataclasses import dataclass, field |
2 | 2 | from typing import Callable, ClassVar, Sequence |
3 | | -from xdsl.dialects.builtin import ModuleOp, IntegerType, StringAttr, UnitAttr |
| 3 | +from xdsl.dialects.builtin import ModuleOp, IntegerType, StringAttr, UnitAttr, ArrayAttr |
4 | 4 | from xdsl.dialects.pdl import ( |
5 | 5 | ApplyNativeConstraintOp, |
6 | 6 | ApplyNativeRewriteOp, |
@@ -290,8 +290,12 @@ def match_and_rewrite(self, op: ReplaceOp, rewriter: PatternRewriter): |
290 | 290 | replacing_value = replacing_values[0] |
291 | 291 |
|
292 | 292 | if self.refinement is None: |
| 293 | + replaced_types = op.attributes["__replaced_types"] |
| 294 | + replacing_types = op.attributes["__replacing_types"] |
| 295 | + assert isa(replaced_types, ArrayAttr[Attribute]) |
| 296 | + assert isa(replacing_types, ArrayAttr[Attribute]) |
293 | 297 | refinement = find_refinement_semantics( |
294 | | - replaced_value.type, replacing_value.type |
| 298 | + replaced_types.data[0], replacing_types.data[0] |
295 | 299 | ) |
296 | 300 | else: |
297 | 301 | refinement = self.refinement |
@@ -529,6 +533,46 @@ def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter): |
529 | 533 | ), "Operations used as computations in PDL should not have effects" |
530 | 534 |
|
531 | 535 |
|
| 536 | +def annotate_replace_op(pattern: PatternOp): |
| 537 | + """ |
| 538 | + Annotate all `pdl.replace` operations in the given pattern with the types of the |
| 539 | + replaced and replacing values. |
| 540 | + """ |
| 541 | + values_to_types: dict[SSAValue, tuple[Attribute, ...]] = {} |
| 542 | + for op in pattern.walk(): |
| 543 | + if isinstance(op, OperandOp): |
| 544 | + if op.value_type is None: |
| 545 | + raise Exception("Cannot handle non-typed operands") |
| 546 | + owner = op.value_type.owner |
| 547 | + assert isinstance(owner, TypeOp) |
| 548 | + assert owner.constantType is not None |
| 549 | + values_to_types[op.value] = (owner.constantType,) |
| 550 | + if isinstance(op, OperationOp): |
| 551 | + values: list[Attribute] = [] |
| 552 | + for type_value in op.type_values: |
| 553 | + owner = type_value.owner |
| 554 | + assert isinstance(owner, TypeOp) |
| 555 | + assert owner.constantType is not None |
| 556 | + values.append(owner.constantType) |
| 557 | + values_to_types[op.op] = tuple(values) |
| 558 | + if isinstance(op, ResultOp): |
| 559 | + values_to_types[op.val] = ( |
| 560 | + values_to_types[op.parent_][op.index.value.data], |
| 561 | + ) |
| 562 | + if isinstance(op, ReplaceOp): |
| 563 | + replaced_value_types = values_to_types[op.op_value] |
| 564 | + if len(op.repl_values) != 0: |
| 565 | + replacing_value_types: list[Attribute] = [] |
| 566 | + for repl_value in op.repl_values: |
| 567 | + replacing_value_types.append(values_to_types[repl_value][0]) |
| 568 | + replacing_types = tuple(replacing_value_types) |
| 569 | + else: |
| 570 | + assert op.repl_operation is not None |
| 571 | + replacing_types = values_to_types[op.repl_operation] |
| 572 | + op.attributes["__replaced_types"] = ArrayAttr(replaced_value_types) |
| 573 | + op.attributes["__replacing_types"] = ArrayAttr(replacing_types) |
| 574 | + |
| 575 | + |
532 | 576 | @dataclass |
533 | 577 | class PDLToSMTLowerer: |
534 | 578 | native_rewrites: dict[ |
@@ -621,6 +665,7 @@ def lower_to_smt(self, module: ModuleOp, ctx: Context) -> None: |
621 | 665 | new_state_op = DeclareConstOp(StateType()) |
622 | 666 | Rewriter.insert_op(new_state_op, insert_point) |
623 | 667 | rewrite_context = PDLToSMTRewriteContext(new_state_op.res, new_state_op.res) |
| 668 | + annotate_replace_op(pattern) |
624 | 669 |
|
625 | 670 | walker = PatternRewriteWalker( |
626 | 671 | GreedyRewritePatternApplier( |
|
0 commit comments