Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 68 additions & 3 deletions src/kirin/ir/nodes/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,47 @@
from kirin.serialization.core.serializationunit import SerializationUnit


def _rpo_block_order(region: Region) -> list[Block]:
    """Order the blocks of ``region`` by reverse post-order of its CFG.

    A depth-first walk starts at the CFG entry block (when there is one). The
    post-order it produces is reversed, and every block of ``region.blocks``
    the walk never reached is then appended, keeping its original relative
    position.

    In well-formed SSA a dominator precedes the blocks it dominates in
    reverse post-order, so walking the result visits the statements defining
    a value before the statements using it.
    """
    from kirin.analysis.cfg import CFG

    graph = CFG(region)
    edges = graph.successors
    root = graph.entry

    seen: set[Block] = set()
    finished: list[Block] = []

    # Explicit-stack DFS. Each entry is (block, expanded): the first time a
    # block is popped it is marked seen and re-pushed with expanded=True
    # underneath its successors, so it is emitted only after all of them.
    pending: list[tuple[Block, bool]] = [] if root is None else [(root, False)]
    while pending:
        current, expanded = pending.pop()
        if expanded:
            finished.append(current)
        elif current not in seen:
            seen.add(current)
            pending.append((current, True))
            pending.extend(
                (child, False)
                for child in edges.get(current, ())
                if child not in seen
            )

    ordered = finished[::-1]
    # Every block that was seen has also been emitted, so ``seen`` doubles as
    # the membership test for "already placed in the ordering".
    ordered.extend(block for block in region.blocks if block not in seen)
    return ordered


@dataclass
class RegionBlocks(MutableSequenceView[list[Block], "Region", Block]):
"""A View object that contains a list of Blocks of a Region.
Expand Down Expand Up @@ -150,6 +191,15 @@ def clone(self, ssamap: dict[SSAValue, SSAValue] | None = None) -> Region:
ret = Region()
successor_map: dict[Block, Block] = {}
_ssamap = ssamap or {}

# Collect all SSA values defined inside this region (statement results).
# Block args are handled separately below.
in_region_defs: set[SSAValue] = set()
for block in self.blocks:
for stmt in block.stmts:
in_region_defs.update(stmt.results)

# First pass: create cloned blocks and block args (order doesn't matter).
for block in self.blocks:
new_block = Block()
ret.blocks.append(new_block)
Expand All @@ -158,12 +208,27 @@ def clone(self, ssamap: dict[SSAValue, SSAValue] | None = None) -> Region:
new_arg = new_block.args.append_from(arg.type, arg.name)
_ssamap[arg] = new_arg

# update statements
for block in self.blocks:
# Compute RPO for statement cloning order.
ordered_blocks = _rpo_block_order(self)

def _map_arg(arg: SSAValue) -> SSAValue:
if arg in _ssamap:
return _ssamap[arg]
if arg in in_region_defs:
raise ValueError(
f"Region.clone: in-region SSA value {arg!r} used before "
f"its defining statement was cloned. This indicates a bug "
f"in block ordering during clone."
)
# Value defined outside the region — keep as-is.
return arg

# Second pass: clone statements in RPO order.
for block in ordered_blocks:
for stmt in block.stmts:
new_stmt = stmt.from_stmt(
stmt,
args=[_ssamap.get(arg, arg) for arg in stmt.args],
args=[_map_arg(arg) for arg in stmt.args],
regions=[region.clone(_ssamap) for region in stmt.regions],
successors=[
successor_map[successor] for successor in stmt.successors
Expand Down
124 changes: 123 additions & 1 deletion test/ir/test_region.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from kirin.prelude import basic_no_opt
from kirin.prelude import basic, ilist, basic_no_opt
from kirin.passes.aggressive.fold import Fold


@basic_no_opt
Expand All @@ -13,3 +14,124 @@ def test_region_clone():
assert factorial.callable_region.clone().is_structurally_equal(
factorial.callable_region
)


# Leaf kernel: two consecutive if/else branches. The first selects `u` from
# `x`; the second computes the returned `v` from that `u`.
@basic
def _leaf(a: bool, b: bool, x: int):
    if a:
        u = x + 1
    else:
        u = x + 2
    # `u` was assigned in both arms above and is consumed in both arms below.
    if b:
        v = u + 3
    else:
        v = u + 4
    return v


# Middle kernel: calls `_leaf` twice (flags swapped on the second call),
# combines the results through two `if flag0` branches, then captures `out`
# in the nested function `fn`, which is mapped over `ilist.range(3)`.
@basic
def _level0(flag0: bool, flag1: bool, x: int):
    base = _leaf(flag0, flag1, x)
    base2 = _leaf(flag1, flag0, x + 1)
    if flag0:
        mix = base + base2
    else:
        mix = base2 + base
    if flag0:
        out = mix + 30
    else:
        out = mix + 40

    # `fn` closes over `out`, which is assigned in both arms of the branch above.
    def fn(y: int):
        return y + out

    mapped = ilist.map(fn, ilist.range(3))
    return mapped[0] + out


# Outer kernel: same layout as `_level0`, but it calls `_level0` instead of
# `_leaf`, offsets the second call by `x + 2`, and adds 31/41 in the second
# branch.
@basic
def _level1(flag0: bool, flag1: bool, x: int):
    base = _level0(flag0, flag1, x)
    base2 = _level0(flag1, flag0, x + 2)
    if flag0:
        mix = base + base2
    else:
        mix = base2 + base
    if flag0:
        out = mix + 31
    else:
        out = mix + 41

    # `fn` closes over `out`, which is assigned in both arms of the branch above.
    def fn(y: int):
        return y + out

    mapped = ilist.map(fn, ilist.range(3))
    return mapped[0] + out


# Entry kernel for the fold test below: forwards its arguments to `_level1`.
@basic
def _fold_target(flag0: bool, flag1: bool, x: int):
    return _level1(flag0, flag1, x)


def _has_unordered_edges(method):
"""True if any stmt arg is owned by a stmt in a later block (by index)."""
stmt_block = {}
for bi, block in enumerate(method.callable_region.blocks):
for stmt in block.stmts:
stmt_block[stmt] = bi
for bi, block in enumerate(method.callable_region.blocks):
for stmt in block.stmts:
for arg in stmt.args:
owner = getattr(arg, "owner", None)
owner_bi = stmt_block.get(owner)
if owner_bi is not None and owner_bi > bi:
return True
return False


def _all_owners_in_region(region):
"""Check every operand in region is owned by a stmt/block inside it."""
region_stmts = set()
region_blocks = set()
for block in region.blocks:
region_blocks.add(block)
for stmt in block.stmts:
region_stmts.add(stmt)

for block in region.blocks:
for stmt in block.stmts:
for arg in stmt.args:
owner = getattr(arg, "owner", None)
if owner is None:
continue
# BlockArgument owner is a Block (check first — Block also has parent)
if hasattr(owner, "stmts"):
if owner not in region_blocks:
return False
# ResultValue owner is a Statement
elif hasattr(owner, "parent"):
if owner not in region_stmts:
return False
return True


def test_region_clone_after_aggressive_fold():
    """Cloning a folded region must yield a self-contained, equal region.

    The aggressive fold pass is applied up to eight times to a copy of
    ``_fold_target``, stopping early once a use appears in an earlier block
    than its definition or once the pass reports no further change. The clone
    is checked whichever way the loop ends.
    """
    method = _fold_target.similar()
    fold_pass = Fold(method.dialects)

    for _ in range(8):
        outcome = fold_pass.unsafe_run(method)
        # Stop on the interesting shape, or when folding has converged.
        if _has_unordered_edges(method) or not outcome.has_done_something:
            break

    original = method.callable_region
    cloned = original.clone()
    assert _all_owners_in_region(
        cloned
    ), "Region.clone produced operands owned by statements outside the cloned region"
    assert cloned.is_structurally_equal(method.callable_region)
Loading