diff --git a/src/kirin/ir/nodes/region.py b/src/kirin/ir/nodes/region.py index 8cbd2d9d2..4e16f9028 100644 --- a/src/kirin/ir/nodes/region.py +++ b/src/kirin/ir/nodes/region.py @@ -146,10 +146,25 @@ def __hash__(self) -> int: def clone(self, ssamap: dict[SSAValue, SSAValue] | None = None) -> Region: """Clone a region. This will clone all blocks and statements in the region. `SSAValue` defined outside the region will not be cloned unless provided in `ssamap`. + + Note: + Blocks must be in definition order (e.g. reverse post-order of the CFG) + so that every statement result is cloned before it is referenced. + Use :class:`kirin.rewrite.SortBlocks` to ensure this after passes + that may reorder blocks. """ ret = Region() successor_map: dict[Block, Block] = {} _ssamap = ssamap or {} + + # Collect all SSA values defined inside this region (statement results). + # Block args are handled separately below. + in_region_defs: set[SSAValue] = set() + for block in self.blocks: + for stmt in block.stmts: + in_region_defs.update(stmt.results) + + # First pass: create cloned blocks and block args (order doesn't matter). for block in self.blocks: new_block = Block() ret.blocks.append(new_block) @@ -158,12 +173,24 @@ def clone(self, ssamap: dict[SSAValue, SSAValue] | None = None) -> Region: new_arg = new_block.args.append_from(arg.type, arg.name) _ssamap[arg] = new_arg - # update statements + def _map_arg(arg: SSAValue) -> SSAValue: + if arg in _ssamap: + return _ssamap[arg] + if arg in in_region_defs: + raise ValueError( + f"Region.clone: in-region SSA value {arg!r} used before " + f"its defining statement was cloned. This indicates a " + f"block ordering issue — run SortBlocks before cloning." + ) + # Value defined outside the region — keep as-is. + return arg + + # Second pass: clone statements in block-list order. 
@dataclass
class SortBlocks(RewriteRule):
    """Reorder blocks in a region to reverse post-order (RPO) of the CFG.

    RPO guarantees that in well-formed SSA, a block's dominator is visited
    before the block itself, so statement results appear before their uses
    in block-list order. This is required for correct ``Region.clone()``
    and benefits any pass that iterates blocks sequentially.

    Blocks that cannot be reached from ``cfg.entry`` keep their original
    relative order and are placed after all reachable blocks.

    Attributes:
        cfg: Control-flow graph that is traversed. Only ``cfg.entry`` and
            ``cfg.successors`` are read; the graph itself is never modified
            by this rule.
    """

    cfg: CFG

    def rewrite_Region(self, node: ir.Region) -> RewriteResult:
        """Sort ``node``'s block list into reverse post-order.

        Args:
            node: Region whose block list is reordered in place.

        Returns:
            ``RewriteResult(has_done_something=True)`` when the block order
            changed, otherwise a default ``RewriteResult()``.
        """
        # NOTE: relies on self.cfg being up-to-date. When used inside
        # CompactifyRegion, prior rules (DeadBlock, CFGEdge, etc.) mutate
        # the shared CFG's successors/predecessors dicts in place.
        successors = self.cfg.successors
        entry = self.cfg.entry

        current_order = list(node.blocks)
        # Only blocks attached to ``node`` may appear in the new order. If the
        # CFG is stale or was built for a different region, following its
        # edges blindly would splice foreign or already-removed blocks into
        # this region's block list.
        region_blocks = set(current_order)

        visited: set[ir.Block] = set()
        post_order: list[ir.Block] = []

        if entry is not None and entry in region_blocks:
            # Iterative DFS. Each stack item is (block, returning): the
            # ``returning`` marker is popped once every successor pushed
            # after it has been processed, which is the block's post-order
            # position.
            stack: list[tuple[ir.Block, bool]] = [(entry, False)]
            while stack:
                block, returning = stack.pop()
                if returning:
                    post_order.append(block)
                    continue
                if block in visited:
                    # The same block can be pushed by several predecessors
                    # before it is first expanded.
                    continue
                visited.add(block)
                stack.append((block, True))
                for succ in successors.get(block, ()):
                    if succ in region_blocks and succ not in visited:
                        stack.append((succ, False))

        new_order = post_order[::-1]

        # Append unreachable blocks in their original order. ``visited`` holds
        # exactly the blocks that were emitted into ``post_order``.
        new_order.extend(block for block in current_order if block not in visited)

        if new_order == current_order:
            return RewriteResult()

        # Reorder in place — blocks are already attached to the region,
        # so we update the internal list and index directly.
        node._blocks[:] = new_order
        node._block_idx = {block: i for i, block in enumerate(new_order)}
        return RewriteResult(has_done_something=True)
def _all_owners_in_region(region):
    """Check every operand in region is owned by a stmt/block inside it.

    Statements nested inside the regions of other statements are checked as
    well: a stale reference left behind inside a nested region is just as
    much a leak as one at the top level.

    Objects are classified by duck typing: an owner exposing ``stmts`` is
    treated as a block, any other owner exposing ``parent`` is treated as a
    statement. Operands without an ``owner`` attribute (or with
    ``owner is None``) are ignored.

    Args:
        region: Object exposing ``blocks``; each block exposes ``stmts`` and
            each statement exposes ``args`` and, optionally, ``regions``.

    Returns:
        ``True`` if every operand's owner is a block or statement found
        inside ``region`` (at any nesting depth), ``False`` otherwise.
    """
    all_blocks = []
    all_stmts = []

    # Collect blocks and statements at every nesting depth first, so that an
    # operand inside a nested region may legitimately refer to a value
    # defined in an enclosing region.
    pending = [region]
    while pending:
        current = pending.pop()
        for block in current.blocks:
            all_blocks.append(block)
            for stmt in block.stmts:
                all_stmts.append(stmt)
                pending.extend(getattr(stmt, "regions", ()))

    region_blocks = set(all_blocks)
    region_stmts = set(all_stmts)

    for stmt in all_stmts:
        for arg in stmt.args:
            owner = getattr(arg, "owner", None)
            if owner is None:
                continue
            # BlockArgument owner is a Block (check first — Block also has parent)
            if hasattr(owner, "stmts"):
                if owner not in region_blocks:
                    return False
            # ResultValue owner is a Statement
            elif hasattr(owner, "parent"):
                if owner not in region_stmts:
                    return False
    return True
def test_sort_blocks_fixes_scrambled_region():
    """SortBlocks must restore a clone-safe order after blocks are scrambled.

    Besides checking that the rule reports a change, this verifies that the
    result is a permutation of the original blocks, that a second run is a
    no-op (the order reached is a fixed point of the rule), and that the
    sorted region can be cloned into a self-contained region.
    """
    mt = _branchy.similar()
    region = mt.callable_region
    original_order = list(region.blocks)

    assert _scramble_blocks(region)
    assert list(region.blocks) != original_order

    sorter = SortBlocks(CFG(region))
    result = sorter.rewrite_Region(region)
    assert result.has_done_something

    # Sorting may only permute blocks, never add or drop any.
    sorted_order = list(region.blocks)
    assert len(sorted_order) == len(original_order)
    assert set(sorted_order) == set(original_order)

    # Re-running on an already sorted region must leave it untouched.
    # NOTE(review): assumes a default RewriteResult() reports
    # has_done_something as falsy — confirm against kirin.rewrite.abc.
    assert not sorter.rewrite_Region(region).has_done_something
    assert list(region.blocks) == sorted_order

    # After sorting, clone must produce a self-contained region.
    cloned = region.clone()
    assert _all_owners_in_region(cloned)