
Commit 4226667

Fix refinement to consider entire action path.
1 parent b107de2 commit 4226667

File tree: 3 files changed (132 additions, 119 deletions)


examples/refine.rs

Lines changed: 0 additions & 33 deletions
This file was deleted.

src/lib.rs

Lines changed: 6 additions & 14 deletions
@@ -68,21 +68,13 @@ fn run_inner(input: Part, config: &Config, progress: &ProgressBar, prev_cost: f6
     }
 
     // Further slicing is necessary. Run the tree search to find the slice plane
-    // which maximizes the future reward. Because the search is discretized,
-    // this is only a coarse estimate of the best slice plane.
-    let coarse_plane = mcts::run(input.clone(), config).expect("no action");
-
-    // Run the refinement step on the chosen slice. The radius specifies the
-    // maximum amount of adjustment in either direction. We allow refinement of
-    // up two one grid width of the MCTS search.
-    let refine_radius = 1. / (config.num_nodes + 1) as f64;
-    let refine_plane = mcts::refine(&input, coarse_plane, refine_radius);
-
-    let slice_plane = refine_plane.unwrap_or(coarse_plane).denormalize(
-        input.bounds.min[coarse_plane.axis],
-        input.bounds.max[coarse_plane.axis],
+    // which maximizes the future reward and also refine a more precise plane.
+    let optimal_plane = mcts::run(&input, config).expect("no action");
+    let abs_optimal_plane = optimal_plane.denormalize(
+        input.bounds.min[optimal_plane.axis],
+        input.bounds.max[optimal_plane.axis],
     );
-    let (refined_l, refined_r) = input.slice(slice_plane);
+    let (lhs, rhs) = input.slice(abs_optimal_plane);
 
     // The input mesh is no longer required. Drop it to save on memory usage.
     drop(input);
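
The reworked call site denormalizes the unit-interval plane returned by `mcts::run` against the part's bounding box before slicing. Below is a minimal, self-contained sketch of that bias-to-absolute mapping; only the `CanonicalPlane` field names and the `denormalize(min, max)` shape come from the diff, while the linear `min + bias * (max - min)` body and the example values are assumptions for illustration, not code from this repository.

```rust
// Hypothetical sketch of the bias denormalization assumed above.
#[derive(Clone, Copy, Debug)]
struct CanonicalPlane {
    axis: usize,
    bias: f64, // position in [0, 1] along `axis`
}

impl CanonicalPlane {
    /// Map the normalized bias to an absolute coordinate, given the part's
    /// bounding-box extent along `axis` (assumed linear mapping).
    fn denormalize(self, min: f64, max: f64) -> CanonicalPlane {
        CanonicalPlane {
            axis: self.axis,
            bias: min + self.bias * (max - min),
        }
    }
}

fn main() {
    // A plane at the 30% mark of the x-axis of a part spanning [-2.0, 6.0].
    let plane = CanonicalPlane { axis: 0, bias: 0.3 };
    let abs = plane.denormalize(-2.0, 6.0);
    assert!((abs.bias - 0.4).abs() < 1e-12); // -2.0 + 0.3 * 8.0 = 0.4
    println!("absolute cut position: {}", abs.bias);
}
```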

src/mcts.rs

Lines changed: 126 additions & 72 deletions
@@ -147,6 +147,38 @@ impl MctsState {
         }
     }
 
+    /// The default policy chooses the highest reward among splitting the part
+    /// directly at the center along three axes. The policy is rolled out until
+    /// the maximum depth is reached.
+    fn simulate(&self, max_depth: usize) -> f64 {
+        let default_planes = [
+            CanonicalPlane { axis: 0, bias: 0.5 },
+            CanonicalPlane { axis: 1, bias: 0.5 },
+            CanonicalPlane { axis: 2, bias: 0.5 },
+        ];
+
+        let mut current_state = self.clone();
+        while !current_state.is_terminal(max_depth) {
+            // PARALLEL: evaluate the axes in parallel.
+            let (_, state_to_play) = default_planes
+                .into_par_iter()
+                .map(|plane| {
+                    let action = Action::new(plane);
+
+                    let new_state = current_state.step(action);
+                    let new_reward = new_state.reward();
+
+                    (new_reward, new_state)
+                })
+                .max_by(|a, b| a.0.total_cmp(&b.0))
+                .unwrap();
+
+            current_state = state_to_play;
+        }
+
+        current_state.quality()
+    }
+
     /// The reward is the inverse of the concavity of the worst part, i.e. a
     /// smaller concavity gives a higher reward. We aim to maximize the reward.
     fn reward(&self) -> f64 {
@@ -225,7 +257,9 @@ impl Mcts {
                 return v;
             }
 
-            v = self.best_child(v, c);
+            v = self
+                .best_child(v, c)
+                .expect("selected leaf node must have parent");
         }
     }
 
@@ -247,38 +281,6 @@ impl Mcts {
         self.nodes[v].children.push(child);
     }
 
-    /// The default policy chooses the highest reward among splitting the part
-    /// directly at the center along one of the three axes. The game is played
-    /// out until the maximum depth is reached.
-    fn simulate(&self, v: usize, max_depth: usize) -> f64 {
-        let mut current_state = self.nodes[v].state.clone();
-        while !current_state.is_terminal(max_depth) {
-            let default_planes = [
-                CanonicalPlane { axis: 0, bias: 0.5 },
-                CanonicalPlane { axis: 1, bias: 0.5 },
-                CanonicalPlane { axis: 2, bias: 0.5 },
-            ];
-
-            // PARALLEL: evaluate the axes in parallel.
-            let (_, state_to_play) = default_planes
-                .into_par_iter()
-                .map(|plane| {
-                    let action = Action::new(plane);
-
-                    let new_state = current_state.step(action);
-                    let new_reward = new_state.reward();
-
-                    (new_reward, new_state)
-                })
-                .max_by(|a, b| a.0.total_cmp(&b.0))
-                .unwrap();
-
-            current_state = state_to_play;
-        }
-
-        current_state.quality()
-    }
-
     /// Upper confidence estimate of the given node's reward.
     fn ucb(&self, v: usize, c: f64) -> f64 {
         if self.nodes[v].n == 0 {
@@ -294,20 +296,15 @@ impl Mcts {
     }
 
     /// The next child to explore, based on the tradeoff of exploration and
-    /// exploitation.
-    fn best_child(&self, v: usize, c: f64) -> usize {
+    /// exploitation. Returns None if there are no children of v.
+    fn best_child(&self, v: usize, c: f64) -> Option<usize> {
         let node = &self.nodes[v];
-        assert!(!node.children.is_empty());
-
-        node.children
-            .iter()
-            .copied()
-            .max_by(|&a, &b| {
-                let ucb_a = self.ucb(a, c);
-                let ucb_b = self.ucb(b, c);
-                ucb_a.total_cmp(&ucb_b)
-            })
-            .unwrap()
+
+        node.children.iter().copied().max_by(|&a, &b| {
+            let ucb_a = self.ucb(a, c);
+            let ucb_b = self.ucb(b, c);
+            ucb_a.total_cmp(&ucb_b)
+        })
     }
 
     /// Propagate rewards at the leaf nodes back up through the tree.
@@ -324,62 +321,105 @@ impl Mcts {
             }
         }
     }
+
+    /// Returns the action path from the root to the highest reward terminal
+    /// node.
+    fn best_path_from_root(&self) -> Vec<Action> {
+        let mut best_path = vec![];
+        let mut v = 0;
+        while let Some(child) = self.best_child(v, 0.0) {
+            if let Some(action) = self.nodes[child].action {
+                best_path.push(action);
+            }
+
+            v = child;
+        }
+
+        best_path
+    }
+}
+
+/// Compute the quality for a path starting at initial_state, with the first
+/// action as replace_initial_action, and the remaining actions as actions[1..].
+///
+/// Used for refinement to determine if a first step replacement results in a
+/// high quality path.
+fn quality_for_path(
+    initial_state: &MctsState,
+    actions: &[Action],
+    replace_initial_action: Action,
+    max_depth: usize,
+) -> f64 {
+    let mut state = initial_state.clone();
+    let mut actions = actions.to_vec();
+    actions[0] = replace_initial_action;
+
+    for action in actions {
+        state = state.step(action);
+    }
+    state.simulate(max_depth)
 }
 
 /// Binary search for a refined cutting plane. Iteratively try cutting the input
-/// to the left and to the right of the initial plane. Whichever cut side
-/// results in a higher reward is recursively refined.
-pub fn refine(
-    input_part: &Part,
-    initial_unit_plane: CanonicalPlane,
+/// to the left and to the right of the initial plane.
+///
+/// To evaluate the cut, the quality of the entire path with the first cut
+/// replaced by the left or right hand side from above is simulated. This
+/// prevents the refinement from being too greedy and reducing future reward.
+fn refine(
+    initial_state: &MctsState,
+    initial_path: &[Action],
     unit_radius: f64,
-) -> Option<CanonicalPlane> {
-    const EPS: f64 = 1e-5;
+    max_depth: usize,
+) -> CanonicalPlane {
+    // Each iteration cuts the search plane in half, so even in the worst case
+    // (traversing the entire unit interval) this should converge in ~20 steps.
+    const EPS: f64 = 1e-6;
 
-    let state = MctsState {
-        parts: vec![input_part.clone()],
-        parent_rewards: vec![],
-        depth: 0,
-    };
+    let initial_action = initial_path[0];
+    let initial_unit_plane = initial_action.unit_plane;
 
     let mut lb = initial_unit_plane.bias - unit_radius;
     let mut ub = initial_unit_plane.bias + unit_radius;
-    let mut best_action = None;
+    let mut best_action = initial_action;
 
-    // Each iteration cuts the search plane in half, so even in the worst case
-    // (traversing the entire unit interval) this should converge in ~20 steps.
+    // Iterate until convergence.
     while (ub - lb) > EPS {
         let pivot = (lb + ub) / 2.0;
 
         let lhs = Action::new(initial_unit_plane.with_bias((lb + pivot) / 2.));
         let rhs = Action::new(initial_unit_plane.with_bias((ub + pivot) / 2.));
 
-        if state.step(lhs).reward() > state.step(rhs).reward() {
+        // Is left or right better?
+        if quality_for_path(initial_state, initial_path, lhs, max_depth)
+            > quality_for_path(initial_state, initial_path, rhs, max_depth)
+        {
            // Move left
            ub = pivot;
-            best_action = Some(lhs);
+            best_action = lhs;
        } else {
            // Move right
            lb = pivot;
-            best_action = Some(rhs);
+            best_action = rhs;
        }
    }
 
-    best_action.map(|a| a.unit_plane)
+    best_action.unit_plane
 }
 
 /// An implementation of Monte Carlo Tree Search for the approximate convex
 /// decomposition via mesh slicing problem.
 ///
 /// A run of the tree search returns the slice with the highest estimated
 /// probability to lead to a large reward when followed by more slices.
-pub fn run(input_part: Part, config: &Config) -> Option<CanonicalPlane> {
+pub fn run(input_part: &Part, config: &Config) -> Option<CanonicalPlane> {
     // A deterministic random number generator.
     let mut rng = ChaCha8Rng::seed_from_u64(config.random_seed);
 
+    // The root MCTS node contains just the input part, unmodified.
     let root_node = MctsNode::new(
         MctsState {
-            parts: vec![input_part],
+            parts: vec![input_part.clone()],
             parent_rewards: vec![],
             depth: 0,
         },
@@ -388,6 +428,8 @@ pub fn run(input_part: Part, config: &Config) -> Option<CanonicalPlane> {
         None,
     );
 
+    // Run the MCTS algorithm for the specified compute time to compute a
+    // probabilistic best path.
     let mut mcts = Mcts::new(root_node);
     for _ in 0..config.iterations {
         let mut v = mcts.select(config.exploration_param);
@@ -398,12 +440,24 @@ pub fn run(input_part: Part, config: &Config) -> Option<CanonicalPlane> {
             v = *children.choose(&mut rng).unwrap();
         }
 
-        let reward = mcts.simulate(v, config.max_depth);
+        let reward = mcts.nodes[v].state.simulate(config.max_depth);
         mcts.backprop(v, reward);
     }
 
-    // For the final result, we only care about the best node. We never want to
-    // return an exploratory node. Set the exploration parameter to zero.
-    let best_node = &mcts.nodes[mcts.best_child(0, 0.0)];
-    best_node.action.map(|a| a.unit_plane)
+    // Take the discrete best path from MCTS and refine it.
+    let best_path = mcts.best_path_from_root();
+    if !best_path.is_empty() {
+        let refined_plane = refine(
+            // Start the refinement from the root state, i.e. just the input
+            // part.
+            &mcts.nodes[0].state,
+            &best_path,
+            // TODO: use one node width scaled to mesh bbox
+            1.0,
+            config.max_depth,
+        );
+        Some(refined_plane)
+    } else {
+        None
+    }
 }
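
The behavioral change in `refine` is that each bisection candidate for the first cut is now scored by `quality_for_path`, which replays the entire best action path and simulates from the resulting state, instead of using the immediate one-step reward. The toy sketch below illustrates that pattern of bisecting the first step while evaluating the whole plan; the names and the objective function are hypothetical and not taken from this repository.

```rust
// Toy illustration: bisect the first cut position, but score each candidate
// by a full-path objective rather than the one-step reward. The objective is
// made up: the plan is best when the first cut sits at 0.4 *given* the rest
// of the path.
fn full_path_quality(first_cut: f64, remaining_cuts: &[f64]) -> f64 {
    let downstream: f64 = remaining_cuts.iter().map(|c| -(c - 0.5).powi(2)).sum();
    -(first_cut - 0.4).powi(2) + downstream
}

fn refine_first_cut(initial: f64, radius: f64, remaining: &[f64]) -> f64 {
    const EPS: f64 = 1e-6;
    let (mut lb, mut ub) = (initial - radius, initial + radius);
    let mut best = initial;

    while ub - lb > EPS {
        let pivot = (lb + ub) / 2.0;
        let lhs = (lb + pivot) / 2.0;
        let rhs = (ub + pivot) / 2.0;

        // Keep whichever half leads to the better *whole plan*.
        if full_path_quality(lhs, remaining) > full_path_quality(rhs, remaining) {
            ub = pivot;
            best = lhs;
        } else {
            lb = pivot;
            best = rhs;
        }
    }
    best
}

fn main() {
    let refined = refine_first_cut(0.5, 0.25, &[0.5, 0.5]);
    println!("refined first cut ~= {refined:.4}"); // converges near 0.4
}
```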
