Skip to content

Commit 9cd093a

Browse files
authored
Merge pull request #40 from astroautomata/faster-rotate
perf: faster version of rotate_tree_in_place
2 parents d727780 + a083f10 commit 9cd093a

3 files changed

Lines changed: 343 additions & 46 deletions

File tree

symbolic_regression/benches/optim.rs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,5 +410,50 @@ fn bench_utils(c: &mut Criterion) {
410410
group.finish();
411411
}
412412

413-
criterion_group!(benches, bench_search, bench_utils);
413+
fn bench_rotate_tree(c: &mut Criterion) {
414+
fn make_mixed_tree<R: Rng>(rng: &mut R, n_leaves: usize, unary_budget: usize) -> Vec<PNode> {
415+
let mut nodes: Vec<PNode> = Vec::with_capacity(n_leaves * 2 + unary_budget);
416+
for _ in 0..n_leaves {
417+
nodes.push(PNode::Var {
418+
feature: rng.random_range(0..16) as u16,
419+
});
420+
}
421+
let mut stack = n_leaves;
422+
let mut unary_left = unary_budget;
423+
while stack > 1 {
424+
if unary_left > 0 && rng.random_bool(0.35) {
425+
nodes.push(PNode::Op {
426+
arity: 1,
427+
op: rng.random_range(0..128) as u16,
428+
});
429+
unary_left -= 1;
430+
continue;
431+
}
432+
433+
let arity = if stack >= 3 && rng.random_bool(0.35) { 3 } else { 2 };
434+
nodes.push(PNode::Op {
435+
arity: arity as u8,
436+
op: rng.random_range(0..128) as u16,
437+
});
438+
stack = stack - arity + 1;
439+
}
440+
nodes
441+
}
442+
443+
let mut gen = StdRng::seed_from_u64(0);
444+
let mut exprs: Vec<PostfixExpr<T, Ops, D>> = (0..1024)
445+
.map(|_| PostfixExpr::new(make_mixed_tree(&mut gen, 16, 8), Vec::new(), Default::default()))
446+
.collect();
447+
let mut rng = FastRand::with_seed(0);
448+
c.bench_function("rotate_tree_in_place/mixed_arity", |b| {
449+
b.iter(|| {
450+
for expr in &mut exprs {
451+
let ok = rotate_tree_in_place(&mut rng, expr);
452+
std::hint::black_box(ok);
453+
}
454+
});
455+
});
456+
}
457+
458+
criterion_group!(benches, bench_search, bench_utils, bench_rotate_tree);
414459
criterion_main!(benches);

symbolic_regression/src/mutation_functions.rs

Lines changed: 171 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::cell::RefCell;
2+
13
use dynamic_expressions::expression::PostfixExpr;
24
use dynamic_expressions::node::PNode;
35
use dynamic_expressions::{Operators, node_utils};
@@ -270,101 +272,225 @@ pub(crate) fn swap_operands_in_place<T, Ops, const D: usize>(rng: &mut Rng, expr
270272
true
271273
}
272274

275+
struct RotateTreeScratch {
276+
sizes: Vec<usize>,
277+
stack: Vec<usize>,
278+
valid_roots: Vec<usize>,
279+
buf: Vec<PNode>,
280+
}
281+
282+
impl RotateTreeScratch {
283+
const fn new() -> Self {
284+
Self {
285+
sizes: Vec::new(),
286+
stack: Vec::new(),
287+
valid_roots: Vec::new(),
288+
buf: Vec::new(),
289+
}
290+
}
291+
}
292+
293+
std::thread_local! {
294+
static ROTATE_TREE_SCRATCH: RefCell<RotateTreeScratch> = const { RefCell::new(RotateTreeScratch::new()) };
295+
}
296+
297+
fn subtree_sizes_into(nodes: &[PNode], sizes: &mut Vec<usize>, stack: &mut Vec<usize>) {
298+
sizes.resize(nodes.len(), 0);
299+
stack.clear();
300+
stack.reserve(nodes.len());
301+
302+
for (i, n) in nodes.iter().enumerate() {
303+
match *n {
304+
PNode::Var { .. } | PNode::Const { .. } => {
305+
sizes[i] = 1;
306+
stack.push(1);
307+
}
308+
PNode::Op { arity, .. } => {
309+
let a = arity as usize;
310+
let mut sum = 1usize;
311+
for _ in 0..a {
312+
sum += stack.pop().expect("invalid postfix (stack underflow)");
313+
}
314+
sizes[i] = sum;
315+
stack.push(sum);
316+
}
317+
}
318+
}
319+
320+
debug_assert_eq!(stack.len(), 1, "invalid postfix (did not reduce to one root)");
321+
}
322+
323+
fn has_op_child(nodes: &[PNode], sizes: &[usize], root_idx: usize, arity: usize) -> bool {
324+
let mut end = root_idx;
325+
for _ in 0..arity {
326+
end = end.checked_sub(1).expect("invalid postfix (child end underflow)");
327+
let child_end = end;
328+
if matches!(nodes[child_end], PNode::Op { .. }) {
329+
return true;
330+
}
331+
end = child_end + 1 - sizes[child_end];
332+
}
333+
false
334+
}
335+
336+
fn fill_child_ranges<const D: usize>(sizes: &[usize], root_idx: usize, arity: usize, out: &mut [(usize, usize); D]) {
337+
debug_assert!(arity > 0);
338+
339+
let mut end = root_idx;
340+
for k in (0..arity).rev() {
341+
end = end.checked_sub(1).expect("invalid postfix (child end underflow)");
342+
let child_end = end;
343+
let sz = sizes[child_end];
344+
let child_start = child_end + 1 - sz;
345+
out[k] = (child_start, child_end);
346+
end = child_start;
347+
}
348+
}
349+
273350
pub fn rotate_tree_in_place<T, Ops, const D: usize>(rng: &mut Rng, expr: &mut PostfixExpr<T, Ops, D>) -> bool {
351+
ROTATE_TREE_SCRATCH.with(|cell| rotate_tree_in_place_impl(cell, rng, expr))
352+
}
353+
354+
fn rotate_tree_in_place_impl<T, Ops, const D: usize>(
355+
cell: &RefCell<RotateTreeScratch>,
356+
rng: &mut Rng,
357+
expr: &mut PostfixExpr<T, Ops, D>,
358+
) -> bool {
274359
// Match SymbolicRegression.jl's `randomly_rotate_tree!`:
275360
// pick a random rotation root where some child is an operator, then
276361
// rotate along a random internal edge (root -> pivot) using a random grandchild.
277-
let sizes = node_utils::subtree_sizes(&expr.nodes);
278-
let mut valid_roots: Vec<usize> = Vec::new();
279-
for (i, n) in expr.nodes.iter().enumerate() {
280-
let PNode::Op { arity, .. } = *n else {
362+
let mut sc = cell.borrow_mut();
363+
let nodes = &mut expr.nodes;
364+
let n = nodes.len();
365+
366+
let RotateTreeScratch {
367+
sizes,
368+
stack,
369+
valid_roots,
370+
buf,
371+
} = &mut *sc;
372+
373+
subtree_sizes_into(nodes, sizes, stack);
374+
375+
// Build valid_roots in the same order as the reference (left-to-right scan).
376+
valid_roots.clear();
377+
valid_roots.reserve(n / 2);
378+
for (i, node) in nodes.iter().enumerate() {
379+
let PNode::Op { arity, .. } = *node else {
281380
continue;
282381
};
283382
let a = arity as usize;
284383
if a == 0 {
285384
continue;
286385
}
287-
let children = child_ranges(&sizes, i, a);
288-
if children.iter().any(|c| matches!(expr.nodes[c.1], PNode::Op { .. })) {
386+
assert!(a <= D);
387+
if has_op_child(nodes, sizes, i, a) {
289388
valid_roots.push(i);
290389
}
291390
}
292391
if valid_roots.is_empty() {
293392
return false;
294393
}
295394

395+
// RNG draw #1: choose root among valid_roots.
296396
let root_idx = valid_roots[usize_range(rng, 0..valid_roots.len())];
297397
let PNode::Op {
298398
arity: root_arity_u8,
299399
op: op_root,
300-
} = expr.nodes[root_idx]
400+
} = nodes[root_idx]
301401
else {
302402
return false;
303403
};
304404
let root_arity = root_arity_u8 as usize;
305-
if root_arity == 0 {
306-
return false;
307-
}
308-
let root_children = child_ranges(&sizes, root_idx, root_arity);
309405

310-
let pivot_positions: Vec<usize> = root_children
311-
.iter()
312-
.enumerate()
313-
.filter_map(|(j, c)| matches!(expr.nodes[c.1], PNode::Op { .. }).then_some(j))
314-
.collect();
315-
if pivot_positions.is_empty() {
406+
let mut root_children: [(usize, usize); D] = [(0, 0); D];
407+
fill_child_ranges(sizes, root_idx, root_arity, &mut root_children);
408+
409+
// RNG draw #2: choose pivot among op-children of root (same ordering as pivot_positions vec).
410+
let mut n_pivots = 0usize;
411+
for &(_, end) in root_children.iter().take(root_arity) {
412+
if matches!(nodes[end], PNode::Op { .. }) {
413+
n_pivots += 1;
414+
}
415+
}
416+
if n_pivots == 0 {
316417
return false;
317418
}
419+
let mut rem = usize_range(rng, 0..n_pivots);
420+
let mut pivot_pos = 0usize;
421+
let mut pivot_root_idx = 0usize;
422+
for (j, &(_, end)) in root_children.iter().enumerate().take(root_arity) {
423+
if matches!(nodes[end], PNode::Op { .. }) {
424+
if rem == 0 {
425+
pivot_pos = j;
426+
pivot_root_idx = end;
427+
break;
428+
}
429+
rem -= 1;
430+
}
431+
}
318432

319-
let pivot_pos = pivot_positions[usize_range(rng, 0..pivot_positions.len())];
320-
let pivot_root_idx = root_children[pivot_pos].1;
321433
let PNode::Op {
322434
arity: pivot_arity_u8,
323435
op: op_pivot,
324-
} = expr.nodes[pivot_root_idx]
436+
} = nodes[pivot_root_idx]
325437
else {
326-
return false;
438+
panic!("expected op node");
327439
};
328440
let pivot_arity = pivot_arity_u8 as usize;
329-
if pivot_arity == 0 {
330-
return false;
331-
}
332-
let pivot_children = child_ranges(&sizes, pivot_root_idx, pivot_arity);
441+
assert!(pivot_arity > 0 && pivot_arity <= D);
333442

443+
let mut pivot_children: [(usize, usize); D] = [(0, 0); D];
444+
fill_child_ranges(sizes, pivot_root_idx, pivot_arity, &mut pivot_children);
445+
446+
// RNG draw #3: choose grandchild among pivot children.
334447
let grandchild_pos = usize_range(rng, 0..pivot_arity);
335448
let grandchild = pivot_children[grandchild_pos];
336449

337-
let (sub_start, sub_end) = node_utils::subtree_range(&sizes, root_idx);
450+
let root_end = root_idx;
451+
let root_start = root_end + 1 - sizes[root_end];
452+
let subtree_len = sizes[root_end];
338453

339-
// Build the rotated version of the old root, with its `pivot_pos` child replaced by `grandchild`.
340-
let mut rotated_root: Vec<PNode> = Vec::with_capacity(sub_end + 1 - sub_start);
341-
for (j, c) in root_children.iter().enumerate() {
342-
if j == pivot_pos {
343-
rotated_root.extend_from_slice(&expr.nodes[grandchild.0..=grandchild.1]);
344-
} else {
345-
rotated_root.extend_from_slice(&expr.nodes[c.0..=c.1]);
346-
}
454+
// Cheap unary cases: memmove only.
455+
if root_arity == 1 {
456+
let insert_pos = grandchild.1 + 1;
457+
let root_node = nodes[root_end];
458+
nodes.copy_within(insert_pos..root_end, insert_pos + 1);
459+
nodes[insert_pos] = root_node;
460+
return true;
461+
}
462+
if pivot_arity == 1 {
463+
let pivot_node = nodes[pivot_root_idx];
464+
nodes.copy_within((pivot_root_idx + 1)..(root_end + 1), pivot_root_idx);
465+
nodes[root_end] = pivot_node;
466+
return true;
347467
}
348-
rotated_root.push(PNode::Op {
349-
arity: root_arity_u8,
350-
op: op_root,
351-
});
352468

353-
// Build the new subtree rooted at `pivot`, replacing its `grandchild_pos` with `rotated_root`.
354-
let mut new_sub: Vec<PNode> = Vec::with_capacity(sub_end + 1 - sub_start);
355-
for (k, c) in pivot_children.iter().enumerate() {
469+
// General case: rebuild subtree into reusable buffer; copy back (no splice).
470+
buf.clear();
471+
buf.reserve(subtree_len);
472+
473+
for (k, &(start, end)) in pivot_children.iter().enumerate().take(pivot_arity) {
356474
if k == grandchild_pos {
357-
new_sub.extend_from_slice(&rotated_root);
475+
for (j, &(cs, ce)) in root_children.iter().enumerate().take(root_arity) {
476+
let (s, e) = if j == pivot_pos { grandchild } else { (cs, ce) };
477+
buf.extend_from_slice(&nodes[s..(e + 1)]);
478+
}
479+
buf.push(PNode::Op {
480+
arity: root_arity_u8,
481+
op: op_root,
482+
});
358483
} else {
359-
new_sub.extend_from_slice(&expr.nodes[c.0..=c.1]);
484+
buf.extend_from_slice(&nodes[start..(end + 1)]);
360485
}
361486
}
362-
new_sub.push(PNode::Op {
487+
buf.push(PNode::Op {
363488
arity: pivot_arity_u8,
364489
op: op_pivot,
365490
});
366491

367-
expr.nodes.splice(sub_start..=sub_end, new_sub);
492+
debug_assert_eq!(buf.len(), subtree_len);
493+
nodes[root_start..=root_end].copy_from_slice(buf);
368494
true
369495
}
370496

0 commit comments

Comments
 (0)