Skip to content

Commit bf40c32

Browse files
authored
Merge pull request #342 from adamantivm/jac/eval_actions
Initial skeleton and binary for self-play rust
2 parents d465c07 + 03bc75e commit bf40c32

File tree

6 files changed

+827
-6
lines changed

6 files changed

+827
-6
lines changed

deep_quoridor/agents.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
Whenever you commit to git, create the commit message starting with "vibe: " and then a one line summary of the changes.
1+
Whenever you commit to git, create the commit message starting with "vibe: " and then a one line summary of the changes.
2+
3+
Whenever you change Rust files, make sure to run cargo fmt to format all files, then check formatting, build, and run before committing.

deep_quoridor/rust/Cargo.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ name = "create_policy_db"
1616
path = "src/bin/create_policy_db.rs"
1717
required-features = ["binary"]
1818

19+
[[bin]]
20+
name = "selfplay"
21+
path = "src/bin/selfplay.rs"
22+
required-features = ["binary"]
23+
1924
[dependencies]
2025
pyo3 = { version = "0.22", features = ["extension-module"], optional = true }
2126
numpy = { version = "0.22", optional = true }
@@ -25,11 +30,13 @@ rayon = "1.10"
2530
rusqlite = { version = "0.32", features = ["bundled"] }
2631
serde = { version = "1", features = ["derive"] }
2732
clap = { version = "4.5", features = ["derive"], optional = true }
33+
ort = { version = "2.0.0-rc.11", optional = true }
34+
anyhow = { version = "1", optional = true }
2835

2936
[features]
3037
default = ["python"]
3138
python = ["pyo3", "numpy"]
32-
binary = ["clap"]
39+
binary = ["clap", "ort", "anyhow"]
3340

3441
[profile.release]
3542
# Enable optimizations for better performance
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
#!/usr/bin/env rust
2+
//! Self-play executable using ONNX inference for Quoridor.
3+
//!
4+
//! This binary loads a trained ONNX model and uses it to evaluate actions
5+
//! on a Quoridor game board, applying the selected action and displaying the result.
6+
7+
use anyhow::{Context, Result};
8+
use ndarray::Array1;
9+
use ort::session::Session;
10+
11+
use quoridor_rs::actions::{get_valid_move_actions, get_valid_wall_actions};
12+
use quoridor_rs::game_state::{apply_action, create_initial_state};
13+
use quoridor_rs::grid_helpers::grid_game_state_to_resnet_input;
14+
15+
/// Convert 4D array to 1D vector for ONNX input
16+
fn array4d_to_vec(arr: &ndarray::Array4<f32>) -> Vec<f32> {
17+
arr.iter().copied().collect()
18+
}
19+
20+
/// Compute softmax values for policy logits.
///
/// Subtracts the maximum logit before exponentiating (the standard
/// numerical-stability trick) so large logits do not overflow `exp`.
/// An empty input yields an empty vector.
///
/// Note: While ORT's OrtOwnedTensor has a softmax method, using it would require
/// copying the logits from the borrowed slice (&[f32]) returned by try_extract_tensor
/// into an owned OrtOwnedTensor structure. This data copy would be inefficient and
/// defeat the purpose of using a pre-built library function, so we implement softmax
/// directly on the borrowed slice instead.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut exp_values: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exp_values.iter().sum();
    // Normalize in place instead of collecting into a second Vec,
    // saving one allocation per inference call.
    for v in &mut exp_values {
        *v /= sum;
    }
    exp_values
}
33+
34+
/// Evaluate an action using the ONNX model
///
/// Runs one forward pass of the policy network on the current game state,
/// then greedily picks the legal action (pawn move or wall placement) whose
/// policy probability is highest.
///
/// Returns the chosen action as [row, col, action_type]
fn evaluate_action(
    session: &mut Session,
    grid: &ndarray::ArrayView2<i8>,
    player_positions: &ndarray::ArrayView2<i32>,
    walls_remaining: &ndarray::ArrayView1<i32>,
    goal_rows: &ndarray::ArrayView1<i32>,
    current_player: i32,
) -> Result<Array1<i32>> {
    // Convert game state to ResNet input format
    let resnet_input_tensor =
        grid_game_state_to_resnet_input(grid, player_positions, walls_remaining, current_player);

    // Convert to ONNX input format: ort wants (shape, flat data) rather than ndarray.
    let shape = resnet_input_tensor.shape().to_vec();
    let data = array4d_to_vec(&resnet_input_tensor);
    let input_value = ort::value::Value::from_array((shape.as_slice(), data))
        .context("Failed to create ResNet input value")?;

    // Run inference
    let outputs = session
        .run(ort::inputs!["input" => input_value])
        .context("Failed to run ResNet inference")?;

    // Extract policy logits
    let policy_logits = outputs["policy_logits"]
        .try_extract_tensor::<f32>()
        .context("Failed to extract policy logits")?;

    // Convert to probabilities.
    // NOTE(review): `.1` is presumably the flat data slice of the
    // (shape, data) pair returned by try_extract_tensor — confirm against
    // the pinned ort 2.0 rc API.
    let policy_probs = softmax(policy_logits.1);

    // Get all valid actions
    let move_actions = get_valid_move_actions(grid, player_positions, current_player);
    let wall_actions = get_valid_wall_actions(
        grid,
        player_positions,
        walls_remaining,
        goal_rows,
        current_player,
    );

    // Calculate action sizes. Inverts grid_size = board_size * 2 + 3
    // (see create_initial_state); integer division makes this exact.
    let grid_width = grid.ncols() as i32;
    let board_size = (grid_width - 4) / 2 + 1;
    let num_move_actions = board_size * board_size;
    let wall_size = board_size - 1;
    let num_wall_actions = wall_size * wall_size;

    // Find best valid action. The flat policy vector is laid out as
    // [move actions | horizontal walls | vertical walls].
    let mut best_action_idx = 0;
    let mut best_prob = f32::NEG_INFINITY;

    // Check move actions
    for i in 0..move_actions.nrows() {
        let row = move_actions[[i, 0]];
        let col = move_actions[[i, 1]];
        let action_idx = (row * board_size + col) as usize;

        // Bounds guard protects against a model whose policy head is
        // smaller than the board's action space.
        if action_idx < policy_probs.len() && policy_probs[action_idx] > best_prob {
            best_prob = policy_probs[action_idx];
            best_action_idx = i;
        }
    }

    // Check wall actions
    for i in 0..wall_actions.nrows() {
        let row = wall_actions[[i, 0]];
        let col = wall_actions[[i, 1]];
        let action_type = wall_actions[[i, 2]];

        // Calculate action index
        let wall_base_idx = if action_type == 1 {
            // Horizontal wall
            (num_move_actions + row * wall_size + col) as usize
        } else {
            // Vertical wall
            (num_move_actions + num_wall_actions + row * wall_size + col) as usize
        };

        if wall_base_idx < policy_probs.len() && policy_probs[wall_base_idx] > best_prob {
            best_prob = policy_probs[wall_base_idx];
            // Wall actions are indexed after all move actions so the two
            // loops share a single best_action_idx.
            best_action_idx = move_actions.nrows() + i;
        }
    }

    // Return the chosen action.
    // NOTE(review): if both action arrays were empty, best_action_idx stays 0
    // and the wall branch would index an empty array (panic) — confirm callers
    // only invoke this with at least one legal action available.
    if best_action_idx < move_actions.nrows() {
        Ok(Array1::from_vec(vec![
            move_actions[[best_action_idx, 0]],
            move_actions[[best_action_idx, 1]],
            move_actions[[best_action_idx, 2]],
        ]))
    } else {
        let wall_idx = best_action_idx - move_actions.nrows();
        Ok(Array1::from_vec(vec![
            wall_actions[[wall_idx, 0]],
            wall_actions[[wall_idx, 1]],
            wall_actions[[wall_idx, 2]],
        ]))
    }
}
138+
139+
/// Print the game board
140+
fn print_board(
141+
grid: &ndarray::ArrayView2<i8>,
142+
player_positions: &ndarray::ArrayView2<i32>,
143+
walls_remaining: &ndarray::ArrayView1<i32>,
144+
) {
145+
let grid_width = grid.ncols() as i32;
146+
let board_size = (grid_width - 4) / 2 + 1;
147+
148+
println!("\n=== Game Board ({}x{}) ===", board_size, board_size);
149+
println!(
150+
"Player 0 (P0): Position ({}, {}), Walls remaining: {}",
151+
player_positions[[0, 0]],
152+
player_positions[[0, 1]],
153+
walls_remaining[0]
154+
);
155+
println!(
156+
"Player 1 (P1): Position ({}, {}), Walls remaining: {}",
157+
player_positions[[1, 0]],
158+
player_positions[[1, 1]],
159+
walls_remaining[1]
160+
);
161+
println!();
162+
163+
// Print the board (showing only player positions and walls)
164+
for row in 0..board_size {
165+
for col in 0..board_size {
166+
let grid_row = (row * 2 + 2) as usize;
167+
let grid_col = (col * 2 + 2) as usize;
168+
169+
let cell = grid[[grid_row, grid_col]];
170+
if cell == 0 {
171+
print!("P0 ");
172+
} else if cell == 1 {
173+
print!("P1 ");
174+
} else {
175+
print!(" . ");
176+
}
177+
}
178+
println!();
179+
}
180+
println!();
181+
}
182+
183+
fn main() -> Result<()> {
184+
println!("=== Quoridor Self-Play with ONNX Inference ===\n");
185+
186+
// Hardcoded model path (relative to rust directory)
187+
let model_path = "../../experiments/onnx/B5W3_resnet_sample.onnx";
188+
189+
println!("Loading ONNX model from: {}", model_path);
190+
191+
// Load ONNX model
192+
let mut session = Session::builder()
193+
.context("Failed to create session builder")?
194+
.commit_from_file(model_path)
195+
.context("Failed to load ONNX model")?;
196+
197+
println!("✓ Model loaded successfully!\n");
198+
199+
// Game configuration (must match the trained model)
200+
let board_size = 5;
201+
let max_walls = 3;
202+
203+
println!(
204+
"Game configuration: {}x{} board, {} walls per player\n",
205+
board_size, board_size, max_walls
206+
);
207+
208+
// Create initial game state
209+
let (mut grid, mut player_positions, mut walls_remaining, goal_rows) =
210+
create_initial_state(board_size, max_walls);
211+
let current_player = 0;
212+
213+
println!("Initial board:");
214+
print_board(
215+
&grid.view(),
216+
&player_positions.view(),
217+
&walls_remaining.view(),
218+
);
219+
220+
// Evaluate action using ONNX model
221+
println!("Evaluating action for Player {}...", current_player);
222+
let action = evaluate_action(
223+
&mut session,
224+
&grid.view(),
225+
&player_positions.view(),
226+
&walls_remaining.view(),
227+
&goal_rows.view(),
228+
current_player,
229+
)?;
230+
231+
println!(
232+
"Selected action: row={}, col={}, type={}",
233+
action[0], action[1], action[2]
234+
);
235+
236+
// Apply the action
237+
apply_action(
238+
&mut grid.view_mut(),
239+
&mut player_positions.view_mut(),
240+
&mut walls_remaining.view_mut(),
241+
current_player,
242+
&action.view(),
243+
);
244+
245+
println!("\nBoard after applying action:");
246+
print_board(
247+
&grid.view(),
248+
&player_positions.view(),
249+
&walls_remaining.view(),
250+
);
251+
252+
println!("✓ Self-play demonstration completed successfully!");
253+
254+
Ok(())
255+
}

deep_quoridor/rust/src/game_state.rs

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,71 @@
11
#![allow(dead_code)]
22

3-
use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
3+
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
44

55
use crate::actions::{ACTION_MOVE, ACTION_WALL_HORIZONTAL, ACTION_WALL_VERTICAL};
66
use crate::grid::{set_wall_cells, CELL_FREE, CELL_WALL};
77

8+
/// Initialize the initial game state for a Quoridor board
9+
///
10+
/// Creates the initial game state for a Quoridor board with:
11+
/// - A grid of size (board_size * 2 + 3) x (board_size * 2 + 3)
12+
/// - Border walls around the perimeter
13+
/// - Players positioned at top and bottom center
14+
/// - Specified number of walls for each player
15+
///
16+
/// # Arguments
17+
/// * `board_size` - The size of the board (e.g., 5 for a 5x5 board, 9 for standard Quoridor)
18+
/// * `max_walls` - Number of walls each player starts with
19+
///
20+
/// # Returns
21+
/// A tuple containing:
22+
/// * `grid` - The game grid with border walls and player positions
23+
/// * `player_positions` - Array of player positions [player_id, [row, col]]
24+
/// * `walls_remaining` - Array of walls remaining for each player
25+
/// * `goal_rows` - Array of goal rows for each player
26+
pub fn create_initial_state(
27+
board_size: i32,
28+
max_walls: i32,
29+
) -> (Array2<i8>, Array2<i32>, Array1<i32>, Array1<i32>) {
30+
let grid_size = (board_size * 2 + 3) as usize;
31+
32+
let mut grid = Array2::<i8>::from_elem((grid_size, grid_size), CELL_FREE);
33+
34+
// Add border walls
35+
for i in 0..2 {
36+
for j in 0..grid_size {
37+
grid[[i, j]] = CELL_WALL;
38+
grid[[grid_size - 1 - i, j]] = CELL_WALL;
39+
grid[[j, i]] = CELL_WALL;
40+
grid[[j, grid_size - 1 - i]] = CELL_WALL;
41+
}
42+
}
43+
44+
let mut player_positions = Array2::<i32>::zeros((2, 2));
45+
let center_col = board_size / 2;
46+
47+
// Player 0 starts at top center
48+
player_positions[[0, 0]] = 0;
49+
player_positions[[0, 1]] = center_col;
50+
// Player 1 starts at bottom center
51+
player_positions[[1, 0]] = board_size - 1;
52+
player_positions[[1, 1]] = center_col;
53+
54+
// Place players on grid (grid coords are board_coords * 2 + 2)
55+
let p0_grid_row = (player_positions[[0, 0]] * 2 + 2) as usize;
56+
let p0_grid_col = (player_positions[[0, 1]] * 2 + 2) as usize;
57+
let p1_grid_row = (player_positions[[1, 0]] * 2 + 2) as usize;
58+
let p1_grid_col = (player_positions[[1, 1]] * 2 + 2) as usize;
59+
60+
grid[[p0_grid_row, p0_grid_col]] = 0;
61+
grid[[p1_grid_row, p1_grid_col]] = 1;
62+
63+
let walls_remaining = Array1::from(vec![max_walls, max_walls]);
64+
let goal_rows = Array1::from(vec![board_size - 1, 0]); // Player 0 wants bottom, Player 1 wants top
65+
66+
(grid, player_positions, walls_remaining, goal_rows)
67+
}
68+
869
/// Check if a player has won by reaching their goal row.
970
///
1071
/// This is a direct port of check_win from qgrid.py.

0 commit comments

Comments
 (0)