Skip to content

Commit 66f5dd0

Browse files
committed
Unify parameter names and improve docs.
1 parent a2be670 commit 66f5dd0

File tree

5 files changed

+57
-51
lines changed

5 files changed

+57
-51
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ kiddo = "5.2.2"
1919
mimalloc = { version = "0.1.48", features = ["v3"] }
2020
nalgebra = { version = "0.34.1", features = ["rand"] }
2121
parry3d-f64 = "0.25"
22-
pyo3 = { version = "0.27.0", features = ["abi3-py39"] }
22+
pyo3 = { version = "0.27.0", features = ["abi3-py39", "experimental-inspect"] }
2323
rand = { version = "0.9.2", default-features = false, features = ["std"] }
2424
rand_chacha = "0.9.0"
2525
rayon = "1.11.0"

python/miniacd/cli.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,26 @@ def random_rgb() -> tuple[int, int, int]:
4646
type=click.FloatRange(min=0.0),
4747
help="Concavity threshold at which mesh parts will be accepted",
4848
)
49-
@click.option(
50-
"--mcts-depth",
51-
default=3,
52-
type=click.IntRange(min=0),
53-
help="Monte Carlo tree search search depth",
54-
)
5549
@click.option(
5650
"--mcts-iterations",
5751
default=150,
5852
type=click.IntRange(min=0),
5953
help="Monte Carlo tree search search iterations",
6054
)
6155
@click.option(
62-
"--mcts-nodes",
56+
"--mcts-depth",
57+
default=3,
58+
type=click.IntRange(min=1),
59+
help="Monte Carlo tree search search depth",
60+
)
61+
@click.option(
62+
"--mcts-grid-nodes",
6363
default=20,
64-
type=click.IntRange(min=0),
64+
type=click.IntRange(min=1),
6565
help="The discretization size in the Monte Carlo tree search",
6666
)
6767
@click.option(
68-
"--seed",
68+
"--mcts-random-seed",
6969
default=0,
7070
type=click.IntRange(min=0),
7171
help="Random generator seed for deterministic output",
@@ -78,8 +78,8 @@ def main(
7878
threshold: float,
7979
mcts_depth: int,
8080
mcts_iterations: int,
81-
mcts_nodes: int,
82-
seed: int,
81+
mcts_grid_nodes: int,
82+
mcts_random_seed: int,
8383
):
8484
"""
8585
miniacd decomposes watertight 3D meshes into convex components
@@ -91,10 +91,10 @@ def main(
9191
parts = miniacd.run(
9292
mesh,
9393
threshold=threshold,
94-
max_depth=mcts_depth,
95-
iterations=mcts_iterations,
96-
num_nodes=mcts_nodes,
97-
random_seed=seed,
94+
mcts_depth=mcts_depth,
95+
mcts_iterations=mcts_iterations,
96+
mcts_grid_nodes=mcts_grid_nodes,
97+
mcts_random_seed=mcts_random_seed,
9898
print=True,
9999
)
100100

src/lib.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,19 @@ pub struct Config {
2727
/// The minimum acceptable concavity metric for an individual part to be
2828
/// accepted.
2929
pub threshold: f64,
30-
/// The number of samples taken from the MCTS before choosing a move.
31-
pub iterations: usize,
30+
/// The number of iterations taken for each MCTS step which chooses a single
31+
/// slicing plane.
32+
pub mcts_iterations: usize,
3233
/// The depth of the MCTS, i.e. the number of lookahead moves when deciding
3334
/// on the next move.
34-
pub max_depth: usize,
35-
pub exploration_param: f64,
35+
pub mcts_depth: usize,
36+
/// The exploration parameter for computing the upper confidence bound
37+
/// (UCB).
38+
pub mcts_exploration: f64,
3639
/// The number of discrete slices taken per axis at each node in the MCTS.
37-
pub num_nodes: usize,
38-
pub random_seed: u64,
40+
pub mcts_grid_nodes: usize,
41+
/// A seed for the deterministic RNG.
42+
pub mcts_random_seed: u64,
3943
/// Print the progress bar? Enable for human users, disable for tests etc.
4044
pub print: bool,
4145
}
@@ -44,11 +48,11 @@ impl Default for Config {
4448
fn default() -> Self {
4549
Self {
4650
threshold: 0.1,
47-
iterations: 150,
48-
max_depth: 3,
49-
exploration_param: f64::sqrt(2.0),
50-
num_nodes: 20,
51-
random_seed: 0,
51+
mcts_iterations: 150,
52+
mcts_depth: 3,
53+
mcts_exploration: f64::sqrt(2.0),
54+
mcts_grid_nodes: 20,
55+
mcts_random_seed: 0,
5256
print: false,
5357
}
5458
}

src/mcts.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ fn refine(
414414
/// probability to lead to a large reward when followed by more slices.
415415
pub fn run(input_part: &Part, config: &Config) -> Option<CanonicalPlane> {
416416
// A deterministic random number generator.
417-
let mut rng = ChaCha8Rng::seed_from_u64(config.random_seed);
417+
let mut rng = ChaCha8Rng::seed_from_u64(config.mcts_random_seed);
418418

419419
// The root MCTS node contains just the input part, unmodified.
420420
let root_node = MctsNode::new(
@@ -423,24 +423,24 @@ pub fn run(input_part: &Part, config: &Config) -> Option<CanonicalPlane> {
423423
parent_rewards: vec![],
424424
depth: 0,
425425
},
426-
all_actions(config.num_nodes, &mut rng),
426+
all_actions(config.mcts_grid_nodes, &mut rng),
427427
None,
428428
None,
429429
);
430430

431431
// Run the MCTS algorithm for the specified compute time to compute a
432432
// probabilistic best path.
433433
let mut mcts = Mcts::new(root_node);
434-
for _ in 0..config.iterations {
435-
let mut v = mcts.select(config.exploration_param);
434+
for _ in 0..config.mcts_iterations {
435+
let mut v = mcts.select(config.mcts_exploration);
436436

437-
if !mcts.nodes[v].is_terminal(config.max_depth) {
438-
mcts.expand(v, config.num_nodes, &mut rng);
437+
if !mcts.nodes[v].is_terminal(config.mcts_depth) {
438+
mcts.expand(v, config.mcts_grid_nodes, &mut rng);
439439
let children = &mcts.nodes[v].children;
440440
v = *children.choose(&mut rng).unwrap();
441441
}
442442

443-
let reward = mcts.nodes[v].state.simulate(config.max_depth);
443+
let reward = mcts.nodes[v].state.simulate(config.mcts_depth);
444444
mcts.backprop(v, reward);
445445
}
446446

@@ -454,7 +454,7 @@ pub fn run(input_part: &Part, config: &Config) -> Option<CanonicalPlane> {
454454
&best_path,
455455
// TODO: use one node width scaled to mesh bbox
456456
1.0,
457-
config.max_depth,
457+
config.mcts_depth,
458458
);
459459
Some(refined_plane)
460460
} else {

src/py.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,29 +37,31 @@ mod pyminiacd {
3737
}
3838

3939
#[pyfunction]
40-
#[pyo3(signature=(mesh,
41-
threshold=0.1,
42-
iterations=150,
43-
max_depth=3,
44-
num_nodes=20,
45-
random_seed=42,
46-
print=true))]
40+
#[pyo3(signature=(
41+
mesh: "PyMesh",
42+
threshold: "float" = 0.1,
43+
mcts_iterations: "int" = 150,
44+
mcts_depth: "int" = 3,
45+
mcts_grid_nodes: "int" = 20,
46+
mcts_random_seed: "int" = 42,
47+
print: "bool" = true
48+
) -> "list[PyMesh]")]
4749
fn run(
4850
mesh: &PyMesh,
4951
threshold: f64,
50-
iterations: usize,
51-
max_depth: usize,
52-
num_nodes: usize,
53-
random_seed: u64,
52+
mcts_iterations: usize,
53+
mcts_depth: usize,
54+
mcts_grid_nodes: usize,
55+
mcts_random_seed: u64,
5456
print: bool,
5557
) -> Vec<PyMesh> {
5658
let config = Config {
5759
threshold,
58-
iterations,
59-
max_depth,
60-
exploration_param: f64::sqrt(2.0),
61-
num_nodes,
62-
random_seed,
60+
mcts_iterations,
61+
mcts_depth,
62+
mcts_exploration: f64::sqrt(2.0),
63+
mcts_grid_nodes,
64+
mcts_random_seed,
6365
print,
6466
};
6567

0 commit comments

Comments
 (0)