Skip to content

Commit fc7fe35

Browse files
authored
Merge pull request #75 from upstat-io/dev
feat(arc): block merge CFG pass with select lowering, approve 4 proposals
2 parents a535e81 + 67c2a36 commit fc7fe35

208 files changed

Lines changed: 9749 additions & 2590 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
//! Phase 1: Remove blocks unreachable from the entry block.
2+
//!
3+
//! Computes reachability via DFS, builds an old→new block ID remap for
4+
//! surviving blocks, filters out dead blocks, and rewrites all block
5+
//! references in surviving terminators. Also remaps `cow_annotations`
6+
//! block indices and drops annotations for dead blocks.
7+
8+
use crate::graph::successor_block_ids;
9+
use crate::ir::{ArcBlockId, ArcFunction, ArcTerminator};
10+
11+
use super::usize_to_block_id;
12+
13+
/// Remove blocks unreachable from the entry block.
14+
pub(crate) fn compact_blocks(func: &mut ArcFunction) {
15+
let num_blocks = func.blocks.len();
16+
if num_blocks == 0 {
17+
return;
18+
}
19+
20+
// DFS reachability from entry.
21+
let mut reachable = vec![false; num_blocks];
22+
let mut stack = vec![func.entry.index()];
23+
while let Some(idx) = stack.pop() {
24+
if idx >= num_blocks || reachable[idx] {
25+
continue;
26+
}
27+
reachable[idx] = true;
28+
for succ in successor_block_ids(&func.blocks[idx].terminator) {
29+
let si = succ.index();
30+
if si < num_blocks && !reachable[si] {
31+
stack.push(si);
32+
}
33+
}
34+
}
35+
36+
// Check if all blocks are reachable — early exit.
37+
if reachable.iter().all(|&r| r) {
38+
return;
39+
}
40+
41+
// Build remap: old index → Some(new index) for reachable, None for dead.
42+
let mut remap: Vec<Option<usize>> = vec![None; num_blocks];
43+
let mut counter = 0usize;
44+
for (i, &is_reachable) in reachable.iter().enumerate() {
45+
if is_reachable {
46+
remap[i] = Some(counter);
47+
counter += 1;
48+
}
49+
}
50+
51+
// Filter to reachable blocks, assigning new sequential IDs.
52+
// We drain blocks/spans to avoid needing Default on ArcBlock.
53+
let old_blocks: Vec<_> = func.blocks.drain(..).collect();
54+
let old_spans: Vec<_> = func.spans.drain(..).collect();
55+
let mut new_blocks = Vec::with_capacity(counter);
56+
let mut new_spans = Vec::with_capacity(counter);
57+
for (i, (mut block, spans)) in old_blocks.into_iter().zip(old_spans).enumerate() {
58+
if reachable[i] {
59+
block.id = remap_to_block_id(remap[i]);
60+
new_blocks.push(block);
61+
new_spans.push(spans);
62+
}
63+
}
64+
65+
// Rewrite targets in surviving blocks.
66+
for block in &mut new_blocks {
67+
remap_terminator_targets(&mut block.terminator, &remap);
68+
}
69+
70+
func.blocks = new_blocks;
71+
func.spans = new_spans;
72+
func.entry = remap_to_block_id(remap[func.entry.index()]);
73+
func.cow_annotations.remap_block_indices(&remap);
74+
}
75+
76+
/// Convert a remap entry to an `ArcBlockId`.
77+
///
78+
/// # Panics
79+
///
80+
/// Panics if the entry is `None` (unreachable block used where
81+
/// reachable was expected) or exceeds `u32::MAX`.
82+
fn remap_to_block_id(entry: Option<usize>) -> ArcBlockId {
83+
let idx = entry.unwrap_or_else(|| panic!("block remap entry is None for a required block"));
84+
usize_to_block_id(idx)
85+
}
86+
87+
/// Rewrite all `ArcBlockId` references in a terminator using a remap table.
88+
fn remap_terminator_targets(term: &mut ArcTerminator, remap: &[Option<usize>]) {
89+
fn remap_id(id: &mut ArcBlockId, remap: &[Option<usize>]) {
90+
*id = remap_to_block_id(remap[id.index()]);
91+
}
92+
93+
match term {
94+
ArcTerminator::Return { .. } | ArcTerminator::Resume | ArcTerminator::Unreachable => {}
95+
ArcTerminator::Jump { target, .. } => remap_id(target, remap),
96+
ArcTerminator::Branch {
97+
then_block,
98+
else_block,
99+
..
100+
} => {
101+
remap_id(then_block, remap);
102+
remap_id(else_block, remap);
103+
}
104+
ArcTerminator::Switch { cases, default, .. } => {
105+
for (_, target) in cases {
106+
remap_id(target, remap);
107+
}
108+
remap_id(default, remap);
109+
}
110+
ArcTerminator::Invoke { normal, unwind, .. } => {
111+
remap_id(normal, remap);
112+
remap_id(unwind, remap);
113+
}
114+
}
115+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//! Phase 2: Downgrade trivial `Invoke` terminators to `Apply` + `Jump`.
2+
//!
3+
//! An invoke is trivial when:
4+
//! 1. `normal != unwind` (same block would route success to `Resume`)
5+
//! 2. The unwind block is empty body + `Resume` terminator + no params
6+
//! 3. The normal block has no params
7+
//! 4. The normal block has exactly one predecessor (the invoking block)
8+
//!
9+
//! The `Invoke { dst, ty, func, args, arg_ownership, normal, unwind }`
10+
//! becomes an `Apply { dst, ty, func, args, arg_ownership }` appended to
11+
//! the block body, with terminator replaced by `Jump { target: normal }`.
12+
13+
use crate::graph::compute_pred_counts;
14+
use crate::ir::{ArcFunction, ArcInstr, ArcTerminator};
15+
16+
use super::usize_to_block_id;
17+
18+
/// Convert trivial `Invoke` terminators to `Apply` + `Jump`.
19+
pub(crate) fn downgrade_trivial_invokes(func: &mut ArcFunction) {
20+
let pred_counts = compute_pred_counts(func);
21+
22+
for block_idx in 0..func.blocks.len() {
23+
// Check if this block has a trivial invoke — extract normal_idx
24+
// and apply fields if so.
25+
let Some(normal_idx) = is_trivial_invoke(func, block_idx, &pred_counts) else {
26+
continue;
27+
};
28+
29+
// Extract invoke fields. We know the terminator is Invoke from
30+
// the check above.
31+
let (dst, ty, callee, args, arg_ownership) = {
32+
let ArcTerminator::Invoke {
33+
dst,
34+
ty,
35+
func: callee,
36+
args,
37+
arg_ownership,
38+
..
39+
} = &func.blocks[block_idx].terminator
40+
else {
41+
continue;
42+
};
43+
(*dst, *ty, *callee, args.clone(), arg_ownership.clone())
44+
};
45+
46+
// Append Apply to body.
47+
func.blocks[block_idx].body.push(ArcInstr::Apply {
48+
dst,
49+
ty,
50+
func: callee,
51+
args,
52+
arg_ownership,
53+
});
54+
55+
// Append None span for the new Apply.
56+
func.spans[block_idx].push(None);
57+
58+
// Replace terminator with Jump.
59+
func.blocks[block_idx].terminator = ArcTerminator::Jump {
60+
target: usize_to_block_id(normal_idx),
61+
args: vec![],
62+
};
63+
}
64+
}
65+
66+
/// Check if a block's `Invoke` terminator is trivial and return the
67+
/// normal successor index if so.
68+
///
69+
/// Returns `None` if the block doesn't have an `Invoke`, or if any of
70+
/// the four criteria for trivial invoke downgrade are not met.
71+
fn is_trivial_invoke(func: &ArcFunction, block_idx: usize, pred_counts: &[usize]) -> Option<usize> {
72+
let ArcTerminator::Invoke { normal, unwind, .. } = &func.blocks[block_idx].terminator else {
73+
return None;
74+
};
75+
76+
// Criterion 1: normal != unwind.
77+
if normal == unwind {
78+
return None;
79+
}
80+
81+
let normal_idx = normal.index();
82+
let unwind_idx = unwind.index();
83+
84+
// Criterion 2: unwind block is trivial (empty + Resume + no params).
85+
let ub = &func.blocks[unwind_idx];
86+
if !ub.body.is_empty() || ub.terminator != ArcTerminator::Resume || !ub.params.is_empty() {
87+
return None;
88+
}
89+
90+
// Criterion 3: normal block has no params.
91+
if !func.blocks[normal_idx].params.is_empty() {
92+
return None;
93+
}
94+
95+
// Criterion 4: normal block has exactly one predecessor.
96+
if pred_counts[normal_idx] != 1 {
97+
return None;
98+
}
99+
100+
Some(normal_idx)
101+
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
//! Phase 4: Merge single-predecessor Jump chains until fixed point.
2+
//!
3+
//! For each block A with terminator `Jump { target: B, args }` where:
4+
//! - A != B (self-loop guard)
5+
//! - B has exactly one predecessor (A)
6+
//! - B is not the entry block
7+
//!
8+
//! Lower B's params as Let bindings (parallel-copy semantics), then
9+
//! merge B's body and spans into A.
10+
//!
11+
//! Runs to fixed point for transitive chains (A → B → C all merge into A).
12+
//! After fixed point, runs a final compaction to remove dead blocks.
13+
14+
use rustc_hash::FxHashSet;
15+
16+
use crate::graph::compute_pred_counts;
17+
use crate::ir::{ArcFunction, ArcInstr, ArcTerminator, ArcValue, ArcVarId, ValueRepr};
18+
19+
use super::compact::compact_blocks;
20+
21+
/// Merge single-predecessor Jump chains until fixed point.
22+
pub(crate) fn merge_jump_chains(func: &mut ArcFunction) {
23+
let mut dead: FxHashSet<usize> = FxHashSet::default();
24+
25+
loop {
26+
let mut changed = false;
27+
let pred_counts = compute_pred_counts(func);
28+
29+
for a_idx in 0..func.blocks.len() {
30+
if dead.contains(&a_idx) {
31+
continue;
32+
}
33+
34+
let (b_idx, jump_args) = {
35+
let ArcTerminator::Jump { target, args } = &func.blocks[a_idx].terminator else {
36+
continue;
37+
};
38+
let b_idx = target.index();
39+
40+
// Self-loop guard.
41+
if a_idx == b_idx {
42+
continue;
43+
}
44+
// B must have exactly one predecessor.
45+
if pred_counts[b_idx] != 1 {
46+
continue;
47+
}
48+
// B must not be the entry block.
49+
if b_idx == func.entry.index() {
50+
continue;
51+
}
52+
// B must not already be dead.
53+
if dead.contains(&b_idx) {
54+
continue;
55+
}
56+
57+
(b_idx, args.clone())
58+
};
59+
60+
let b_params = func.blocks[b_idx].params.clone();
61+
62+
// Arity check: Jump args must match target block params.
63+
debug_assert_eq!(
64+
b_params.len(),
65+
jump_args.len(),
66+
"Jump args/params arity mismatch: block {a_idx} → block {b_idx}",
67+
);
68+
if b_params.len() != jump_args.len() {
69+
continue;
70+
}
71+
72+
// Lower parallel-copy semantics: block params → Let bindings.
73+
lower_parallel_copy(func, a_idx, &b_params, &jump_args);
74+
75+
// Remap COW annotations: B's entries → A's coordinates.
76+
let offset = func.blocks[a_idx].body.len();
77+
func.cow_annotations.remap_block_merge(b_idx, a_idx, offset);
78+
79+
// Merge B's body into A.
80+
let b_body: Vec<ArcInstr> = func.blocks[b_idx].body.drain(..).collect();
81+
func.blocks[a_idx].body.extend(b_body);
82+
83+
// Merge B's spans into A.
84+
let b_spans: Vec<Option<ori_ir::Span>> = func.spans[b_idx].drain(..).collect();
85+
func.spans[a_idx].extend(b_spans);
86+
87+
// Replace A's terminator with B's.
88+
let b_term = std::mem::replace(
89+
&mut func.blocks[b_idx].terminator,
90+
ArcTerminator::Unreachable,
91+
);
92+
func.blocks[a_idx].terminator = b_term;
93+
94+
// Mark B as dead.
95+
dead.insert(b_idx);
96+
changed = true;
97+
}
98+
99+
if !changed {
100+
break;
101+
}
102+
}
103+
104+
// Final compaction: remove dead blocks.
105+
if !dead.is_empty() {
106+
compact_blocks(func);
107+
}
108+
}
109+
110+
/// Lower block-param parallel-copy semantics to sequential Let bindings.
111+
///
112+
/// Jump args are parallel phi inputs — all args are read before any param
113+
/// is written. When no arg aliases a target param, direct Let is safe.
114+
/// When overlap exists (e.g., swap: `Jump { args: [p1, p0] }` → params
115+
/// `[p0, p1]`), we use fresh temps to avoid clobbering.
116+
fn lower_parallel_copy(
117+
func: &mut ArcFunction,
118+
block_idx: usize,
119+
params: &[(ArcVarId, ori_types::Idx)],
120+
args: &[ArcVarId],
121+
) {
122+
if params.is_empty() {
123+
return;
124+
}
125+
126+
// Check for overlap: does any arg alias a target param?
127+
let param_vars: FxHashSet<ArcVarId> = params.iter().map(|(v, _)| *v).collect();
128+
let has_overlap = args.iter().any(|a| param_vars.contains(a));
129+
130+
if has_overlap {
131+
// Slow path: copy all args to fresh temps first, then temps to params.
132+
// Use fresh_var_repr to preserve repr metadata for ref-typed params.
133+
let temps: Vec<ArcVarId> = args
134+
.iter()
135+
.zip(params.iter())
136+
.map(|(arg, (_, ty))| {
137+
let repr = func.var_repr(*arg).unwrap_or(ValueRepr::Scalar);
138+
func.fresh_var_repr(*ty, repr)
139+
})
140+
.collect();
141+
142+
// Phase 1: args → temps.
143+
for ((&arg, temp), (_, ty)) in args.iter().zip(temps.iter()).zip(params.iter()) {
144+
func.blocks[block_idx].body.push(ArcInstr::Let {
145+
dst: *temp,
146+
ty: *ty,
147+
value: ArcValue::Var(arg),
148+
});
149+
func.spans[block_idx].push(None);
150+
}
151+
152+
// Phase 2: temps → params.
153+
for ((param_var, param_ty), temp) in params.iter().zip(temps.iter()) {
154+
func.blocks[block_idx].body.push(ArcInstr::Let {
155+
dst: *param_var,
156+
ty: *param_ty,
157+
value: ArcValue::Var(*temp),
158+
});
159+
func.spans[block_idx].push(None);
160+
}
161+
} else {
162+
// Fast path: no aliasing, direct Let is safe.
163+
for ((param_var, param_ty), &arg) in params.iter().zip(args.iter()) {
164+
func.blocks[block_idx].body.push(ArcInstr::Let {
165+
dst: *param_var,
166+
ty: *param_ty,
167+
value: ArcValue::Var(arg),
168+
});
169+
func.spans[block_idx].push(None);
170+
}
171+
}
172+
}

0 commit comments

Comments
 (0)