@@ -348,16 +348,14 @@ fn quality_for_path(
348348 initial_state : & MctsState ,
349349 actions : & [ Action ] ,
350350 replace_initial_action : Action ,
351- max_depth : usize ,
352351) -> f64 {
353352 let mut state = initial_state. clone ( ) ;
354- let mut actions = actions. to_vec ( ) ;
355- actions[ 0 ] = replace_initial_action;
356353
357- for action in actions {
358- state = state. step ( action) ;
354+ state = state. step ( replace_initial_action) ;
355+ for action in & actions[ 1 ..] {
356+ state = state. step ( * action) ;
359357 }
360- state. simulate ( max_depth )
358+ state. quality ( )
361359}
362360
363361/// Binary search for a refined cutting plane. Iteratively try cutting the input
@@ -366,45 +364,51 @@ fn quality_for_path(
366364/// To evaluate the cut, the quality of the entire path with the first cut
367365/// replaced by the left or right hand side from above is simulated. This
368366/// prevents the refinement from being too greedy and reducing future reward.
369- fn refine (
370- initial_state : & MctsState ,
371- initial_path : & [ Action ] ,
372- unit_radius : f64 ,
373- max_depth : usize ,
374- ) -> CanonicalPlane {
367+ fn refine ( initial_state : & MctsState , initial_path : & [ Action ] , unit_radius : f64 ) -> CanonicalPlane {
375368 // Each iteration cuts the search plane in half, so even in the worst case
376369 // (traversing the entire unit interval) this should converge in ~20 steps.
377370 const EPS : f64 = 1e-6 ;
378371
379372 let initial_action = initial_path[ 0 ] ;
380373 let initial_unit_plane = initial_action. unit_plane ;
374+ let initial_q = quality_for_path ( initial_state, initial_path, initial_path[ 0 ] ) ;
381375
382376 let mut lb = initial_unit_plane. bias - unit_radius;
383377 let mut ub = initial_unit_plane. bias + unit_radius;
384378 let mut best_action = initial_action;
385379
386380 // Iterate until convergence.
381+ let mut new_q = initial_q;
387382 while ( ub - lb) > EPS {
388383 let pivot = ( lb + ub) / 2.0 ;
389384
390385 let lhs = Action :: new ( initial_unit_plane. with_bias ( ( lb + pivot) / 2. ) ) ;
391386 let rhs = Action :: new ( initial_unit_plane. with_bias ( ( ub + pivot) / 2. ) ) ;
392387
388+ let lhs_q = quality_for_path ( initial_state, initial_path, lhs) ;
389+ let rhs_q = quality_for_path ( initial_state, initial_path, rhs) ;
390+
393391 // Is left or right better?
394- if quality_for_path ( initial_state, initial_path, lhs, max_depth)
395- > quality_for_path ( initial_state, initial_path, rhs, max_depth)
396- {
392+ if lhs_q > rhs_q {
397393 // Move left
398394 ub = pivot;
399395 best_action = lhs;
396+ new_q = lhs_q;
400397 } else {
401398 // Move right
402399 lb = pivot;
403400 best_action = rhs;
401+ new_q = rhs_q;
404402 }
405403 }
406404
407- best_action. unit_plane
405+ // TODO: Understand why the refined plane could be worse than the initial.
406+ // Falling into a local minimum?
407+ if new_q > initial_q {
408+ best_action. unit_plane
409+ } else {
410+ initial_unit_plane
411+ }
408412}
409413
410414/// An implementation of Monte Carlo Tree Search for the approximate convex
@@ -447,16 +451,20 @@ pub fn run(input_part: &Part, config: &Config) -> Option<CanonicalPlane> {
447451 // Take the discrete best path from MCTS and refine it.
448452 let best_path = mcts. best_path_from_root ( ) ;
449453 if !best_path. is_empty ( ) {
454+ let coarse_plane = best_path[ 0 ] . unit_plane ;
455+ let lb = input_part. bounds . min [ coarse_plane. axis ] ;
456+ let ub = input_part. bounds . max [ coarse_plane. axis ] ;
457+
450458 let refined_plane = refine (
451459 // Start the refinement from the root state, i.e. just the input
452460 // part.
453461 & mcts. nodes [ 0 ] . state ,
454462 & best_path,
455- // TODO: use one node width scaled to mesh bbox
456- 1.0 ,
457- config. mcts_depth ,
463+ // Refinement is only allowed to adjust within a single grid span.
464+ 1.0 / ( config. mcts_grid_nodes + 1 ) as f64 ,
458465 ) ;
459- Some ( refined_plane)
466+
467+ Some ( refined_plane. denormalize ( lb, ub) )
460468 } else {
461469 None
462470 }
0 commit comments