Skip to content

Commit f965d71

Browse files
author
Julian Cerruti
committed
vibe: add --trace flag for step-by-step game tracing
1 parent 9b71a9d commit f965d71

File tree

8 files changed

+191
-28
lines changed

8 files changed

+191
-28
lines changed

deep_quoridor/rust/src/actions.rs

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,10 @@ mod tests {
407407
for idx in 0..total {
408408
let action = action_index_to_action(board_size, idx);
409409
let recovered = action_to_index(board_size, &action);
410-
assert_eq!(idx, recovered, "board_size={board_size}, idx={idx}, action={action:?}");
410+
assert_eq!(
411+
idx, recovered,
412+
"board_size={board_size}, idx={idx}, action={action:?}"
413+
);
411414
}
412415
}
413416
}
@@ -487,20 +490,37 @@ mod tests {
487490
// Count valid actions from full mask
488491
let mask_count: usize = full_mask.iter().filter(|&&x| x).count();
489492
let expected_count = move_actions.nrows() + wall_actions.nrows();
490-
assert_eq!(mask_count, expected_count, "mask has {mask_count} true entries, expected {expected_count}");
493+
assert_eq!(
494+
mask_count, expected_count,
495+
"mask has {mask_count} true entries, expected {expected_count}"
496+
);
491497

492498
// Verify each move action is in the mask
493499
for i in 0..move_actions.nrows() {
494-
let action = [move_actions[[i, 0]], move_actions[[i, 1]], move_actions[[i, 2]]];
500+
let action = [
501+
move_actions[[i, 0]],
502+
move_actions[[i, 1]],
503+
move_actions[[i, 2]],
504+
];
495505
let idx = action_to_index(board_size, &action);
496-
assert!(full_mask[idx], "Move action {action:?} at index {idx} not in mask");
506+
assert!(
507+
full_mask[idx],
508+
"Move action {action:?} at index {idx} not in mask"
509+
);
497510
}
498511

499512
// Verify each wall action is in the mask
500513
for i in 0..wall_actions.nrows() {
501-
let action = [wall_actions[[i, 0]], wall_actions[[i, 1]], wall_actions[[i, 2]]];
514+
let action = [
515+
wall_actions[[i, 0]],
516+
wall_actions[[i, 1]],
517+
wall_actions[[i, 2]],
518+
];
502519
let idx = action_to_index(board_size, &action);
503-
assert!(full_mask[idx], "Wall action {action:?} at index {idx} not in mask");
520+
assert!(
521+
full_mask[idx],
522+
"Wall action {action:?} at index {idx} not in mask"
523+
);
504524
}
505525
}
506526
}

deep_quoridor/rust/src/agents/onnx_agent.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,12 @@ impl ActionSelector for OnnxAgent {
4444
action_mask: &[bool],
4545
) -> Result<(usize, Vec<f32>)> {
4646
// Build ResNet input tensor
47-
let resnet_input =
48-
grid_game_state_to_resnet_input(grid, player_positions, walls_remaining, current_player);
47+
let resnet_input = grid_game_state_to_resnet_input(
48+
grid,
49+
player_positions,
50+
walls_remaining,
51+
current_player,
52+
);
4953

5054
// Convert to flat vec for ORT
5155
let shape = resnet_input.shape().to_vec();

deep_quoridor/rust/src/agents/random_agent.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,11 @@ mod tests {
8282
&mask,
8383
)
8484
.unwrap();
85-
assert!(mask[idx], "RandomAgent picked an invalid action index {}", idx);
85+
assert!(
86+
mask[idx],
87+
"RandomAgent picked an invalid action index {}",
88+
idx
89+
);
8690
}
8791
}
8892

deep_quoridor/rust/src/bin/selfplay.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ struct Cli {
5151
/// Use "random" for a random agent.
5252
#[arg(long)]
5353
p2: Option<String>,
54+
55+
/// Print a step-by-step trace of each game (whose turn, action, board).
56+
#[arg(long, default_value = "false")]
57+
trace: bool,
5458
}
5559

5660
fn main() -> Result<()> {
@@ -77,7 +81,10 @@ fn main() -> Result<()> {
7781
random_p2 = Some(RandomAgent::new());
7882
}
7983
Some(other) => {
80-
anyhow::bail!("Unknown --p2 agent type: '{}'. Valid options: random", other);
84+
anyhow::bail!(
85+
"Unknown --p2 agent type: '{}'. Valid options: random",
86+
other
87+
);
8188
}
8289
None => {
8390
onnx_p2 = Some(OnnxAgent::new(&cli.model_path)?);
@@ -106,6 +113,7 @@ fn main() -> Result<()> {
106113
q.board_size,
107114
q.max_walls,
108115
q.max_steps as i32,
116+
cli.trace,
109117
)?;
110118

111119
// Update stats
@@ -149,6 +157,9 @@ fn main() -> Result<()> {
149157
}
150158
}
151159

152-
println!("Done. {} games written to {}", cli.num_games, cli.output_dir);
160+
println!(
161+
"Done. {} games written to {}",
162+
cli.num_games, cli.output_dir
163+
);
153164
Ok(())
154165
}

deep_quoridor/rust/src/game_runner.rs

Lines changed: 134 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,110 @@
77
88
use ndarray::Array3;
99

10-
use crate::actions::{action_index_to_action, compute_full_action_mask, policy_size};
10+
use crate::actions::{
11+
action_index_to_action, compute_full_action_mask, policy_size, ACTION_MOVE,
12+
ACTION_WALL_HORIZONTAL, ACTION_WALL_VERTICAL,
13+
};
1114
use crate::agents::ActionSelector;
1215
use crate::game_state::{apply_action, check_win, create_initial_state};
16+
use crate::grid::CELL_WALL;
1317
use crate::grid_helpers::grid_game_state_to_resnet_input;
1418
use crate::rotation::{
15-
rotate_action_coords, rotate_goal_rows,
16-
rotate_grid_180, rotate_player_positions,
19+
rotate_action_coords, rotate_goal_rows, rotate_grid_180, rotate_player_positions,
1720
};
1821

22+
/// Format an action triple as a human-readable string.
///
/// Produces a description for `ACTION_MOVE`, `ACTION_WALL_HORIZONTAL` and
/// `ACTION_WALL_VERTICAL`, embedding `row` and `col` as given. Any other
/// `action_type` yields an "Unknown action type" message instead of
/// panicking. `_board_size` is not used by the body (hence the leading
/// underscore).
23+
fn format_action(_board_size: i32, row: i32, col: i32, action_type: i32) -> String {
24+
match action_type {
25+
ACTION_MOVE => format!("Move to ({}, {})", row, col),
26+
ACTION_WALL_HORIZONTAL => format!("Place horizontal wall at ({}, {})", row, col),
27+
ACTION_WALL_VERTICAL => format!("Place vertical wall at ({}, {})", row, col),
28+
// Fallback: report the raw value so an unexpected type is still visible.
_ => format!("Unknown action type {}", action_type),
29+
}
30+
}
31+
32+
/// Render the board state as a human-readable string.
33+
///
34+
/// Shows player positions as `1` and `2`, walls as `|` (vertical) and `-`
35+
/// (horizontal), and empty cells as `.`.
36+
///
37+
/// The board is always shown in the original (un-rotated) orientation.
/// Note that this function performs no rotation itself: it renders `grid`
/// and `player_positions` exactly as passed, so the caller is responsible
/// for supplying un-rotated state.
///
/// Output layout: a header line with the column indices, then one line per
/// board row prefixed with the row index, with a wall line between
/// consecutive rows. `walls_remaining[0]` and `walls_remaining[1]` are
/// appended to board rows 0 and 1 respectively.
///
/// Wall lookup: the gap to the right of board cell `(row, col)` is read from
/// `grid[[2 * row + 2, 2 * col + 3]]`, and the gap below it from
/// `grid[[2 * row + 3, 2 * col + 2]]`; a gap equal to `CELL_WALL` is drawn
/// as a wall.
///
/// # Panics
///
/// The arrays are indexed directly, so this panics (ndarray out-of-bounds)
/// if `player_positions` is smaller than 2x2, `walls_remaining` has fewer
/// than two entries while `board_size >= 2`, or `grid` does not cover the
/// gap coordinates implied by `board_size`.
38+
pub fn display_board(
39+
grid: &ndarray::ArrayView2<i8>,
40+
player_positions: &ndarray::ArrayView2<i32>,
41+
walls_remaining: &ndarray::ArrayView1<i32>,
42+
board_size: i32,
43+
) -> String {
44+
let mut out = String::new();
45+
// `as` does not check the sign: a negative `board_size` would wrap to a
// very large `usize`.
let bs = board_size as usize;
46+
47+
// Column header
48+
out.push_str(" ");
49+
for c in 0..bs {
50+
out.push_str(&format!(" {} ", c));
51+
}
52+
out.push('\n');
53+
54+
// Player coordinates: row 0 of `player_positions` is drawn as `1`, row 1
// as `2`. A negative coordinate wraps under `as usize` and therefore never
// matches any cell, so that player is simply not drawn.
let p0_row = player_positions[[0, 0]] as usize;
55+
let p0_col = player_positions[[0, 1]] as usize;
56+
let p1_row = player_positions[[1, 0]] as usize;
57+
let p1_col = player_positions[[1, 1]] as usize;
58+
59+
for row in 0..bs {
60+
// --- cell row ---
61+
out.push_str(&format!("{:>3} ", row));
62+
for col in 0..bs {
63+
// cell content
// Player 1 takes precedence if both positions coincide.
64+
if row == p0_row && col == p0_col {
65+
out.push('1');
66+
} else if row == p1_row && col == p1_col {
67+
out.push('2');
68+
} else {
69+
out.push('.');
70+
}
71+
72+
// vertical wall to the right
// (skipped for the last column, which has no right-hand neighbour)
73+
if col < bs - 1 {
74+
// Grid coord of the gap between (row,col) and (row,col+1)
// `row` and `col` are already `usize`, so the casts below are no-ops.
75+
let gr = (row * 2 + 2) as usize;
76+
let gc = (col * 2 + 3) as usize;
77+
if grid[[gr, gc]] == CELL_WALL {
78+
out.push_str(" | ");
79+
} else {
80+
out.push_str(" ");
81+
}
82+
}
83+
}
84+
// Metadata on the right of first two rows
85+
match row {
86+
0 => out.push_str(&format!(" P1 walls: {}", walls_remaining[0])),
87+
1 => out.push_str(&format!(" P2 walls: {}", walls_remaining[1])),
88+
_ => {}
89+
}
90+
out.push('\n');
91+
92+
// --- horizontal wall row between this row and the next ---
// (skipped after the last board row)
93+
if row < bs - 1 {
94+
out.push_str(" ");
95+
for col in 0..bs {
96+
// Grid coord of the gap between (row,col) and (row+1,col)
97+
let gr = (row * 2 + 3) as usize;
98+
let gc = (col * 2 + 2) as usize;
99+
if grid[[gr, gc]] == CELL_WALL {
100+
out.push('-');
101+
} else {
102+
out.push(' ');
103+
}
104+
// Spacer between columns; none after the last column.
if col < bs - 1 {
105+
out.push_str(" ");
106+
}
107+
}
108+
out.push('\n');
109+
}
110+
}
111+
out
112+
}
113+
19114
/// One turn's training data, stored in "current-player-faces-downward" coords.
20115
pub struct ReplayBufferItem {
21116
/// ResNet input tensor (5, M, M) — the batch dimension is squeezed out.
@@ -46,12 +141,17 @@ pub struct GameResult {
46141
/// Player 0 moves first. When Player 1 is the current player, the board is
47142
/// rotated 180° before being passed to `agent_p2` so the network always sees
48143
/// "current player moving downward".
144+
///
145+
/// When `trace` is `true`, each step prints whose turn it is, the action
146+
/// chosen, and the resulting board state in the original (un-rotated)
147+
/// orientation.
49148
pub fn play_game(
50149
agent_p1: &mut dyn ActionSelector,
51150
agent_p2: &mut dyn ActionSelector,
52151
board_size: i32,
53152
max_walls: i32,
54153
max_steps: i32,
154+
trace: bool,
55155
) -> anyhow::Result<GameResult> {
56156
let (mut grid, mut player_positions, mut walls_remaining, goal_rows) =
57157
create_initial_state(board_size, max_walls);
@@ -116,9 +216,7 @@ pub fn play_game(
116216
)?;
117217

118218
// Store replay item (in rotated frame — the frame the model saw)
119-
let input_3d = resnet_input
120-
.index_axis(ndarray::Axis(0), 0)
121-
.to_owned();
219+
let input_3d = resnet_input.index_axis(ndarray::Axis(0), 0).to_owned();
122220
replay_items.push(ReplayBufferItem {
123221
input_array: input_3d,
124222
policy: policy.clone(),
@@ -132,7 +230,12 @@ pub fn play_game(
132230

133231
// If Player 1, un-rotate action coordinates back to original frame
134232
let (a_row, a_col, a_type) = if current_player == 1 {
135-
rotate_action_coords(board_size, action_triple[0], action_triple[1], action_triple[2])
233+
rotate_action_coords(
234+
board_size,
235+
action_triple[0],
236+
action_triple[1],
237+
action_triple[2],
238+
)
136239
} else {
137240
(action_triple[0], action_triple[1], action_triple[2])
138241
};
@@ -147,6 +250,25 @@ pub fn play_game(
147250
&action_arr.view(),
148251
);
149252

253+
if trace {
254+
let player_label = if current_player == 0 { "P1" } else { "P2" };
255+
println!(
256+
"--- Step {} | {} ---\n{}",
257+
step + 1,
258+
player_label,
259+
format_action(board_size, a_row, a_col, a_type),
260+
);
261+
print!(
262+
"{}\n",
263+
display_board(
264+
&grid.view(),
265+
&player_positions.view(),
266+
&walls_remaining.view(),
267+
board_size
268+
),
269+
);
270+
}
271+
150272
// Check win
151273
if check_win(&player_positions.view(), &goal_rows.view(), current_player) {
152274
winner = Some(current_player);
@@ -209,7 +331,7 @@ mod tests {
209331
fn test_play_game_completes() {
210332
let mut p1 = FirstValidAgent;
211333
let mut p2 = FirstValidAgent;
212-
let result = play_game(&mut p1, &mut p2, 5, 3, 200).unwrap();
334+
let result = play_game(&mut p1, &mut p2, 5, 3, 200, false).unwrap();
213335

214336
// Game should complete within 200 steps on a 5×5 board
215337
assert!(result.num_turns > 0);
@@ -220,7 +342,7 @@ mod tests {
220342
fn test_play_game_alternating_players() {
221343
let mut p1 = FirstValidAgent;
222344
let mut p2 = FirstValidAgent;
223-
let result = play_game(&mut p1, &mut p2, 5, 0, 200).unwrap();
345+
let result = play_game(&mut p1, &mut p2, 5, 0, 200, false).unwrap();
224346

225347
// With 0 walls the game should end quickly via moves only
226348
// Players should alternate
@@ -233,7 +355,7 @@ mod tests {
233355
fn test_play_game_winner_values() {
234356
let mut p1 = FirstValidAgent;
235357
let mut p2 = FirstValidAgent;
236-
let result = play_game(&mut p1, &mut p2, 5, 0, 200).unwrap();
358+
let result = play_game(&mut p1, &mut p2, 5, 0, 200, false).unwrap();
237359

238360
if let Some(w) = result.winner {
239361
for item in &result.replay_items {
@@ -251,7 +373,7 @@ mod tests {
251373
let mut p1 = FirstValidAgent;
252374
let mut p2 = FirstValidAgent;
253375
// Very short max_steps to force truncation
254-
let result = play_game(&mut p1, &mut p2, 5, 3, 2).unwrap();
376+
let result = play_game(&mut p1, &mut p2, 5, 3, 2, false).unwrap();
255377

256378
if result.winner.is_none() {
257379
for item in &result.replay_items {
@@ -264,7 +386,7 @@ mod tests {
264386
fn test_replay_items_have_correct_shapes() {
265387
let mut p1 = FirstValidAgent;
266388
let mut p2 = FirstValidAgent;
267-
let result = play_game(&mut p1, &mut p2, 5, 3, 200).unwrap();
389+
let result = play_game(&mut p1, &mut p2, 5, 3, 200, false).unwrap();
268390

269391
let grid_size = 5 * 2 + 3; // 13
270392
let total_actions = policy_size(5);

deep_quoridor/rust/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ pub mod compact;
1111
pub mod game_state;
1212
pub mod grid;
1313
pub mod grid_helpers;
14-
pub mod rotation;
1514
mod minimax;
1615
mod pathfinding;
16+
pub mod rotation;
1717
mod validation;
1818

1919
pub mod agents;

deep_quoridor/rust/src/replay_writer.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@ pub fn write_game_npz<P: AsRef<Path>>(path: P, result: &GameResult) -> Result<()
6969
// Stack action_masks → (N, policy_size) as f32
7070
let mut mask_data = Vec::with_capacity(n * policy_len);
7171
for item in items {
72-
mask_data.extend(item.action_mask.iter().map(|&b| if b { 1.0f32 } else { 0.0f32 }));
72+
mask_data.extend(
73+
item.action_mask
74+
.iter()
75+
.map(|&b| if b { 1.0f32 } else { 0.0f32 }),
76+
);
7377
}
7478
let action_masks =
7579
Array2::<f32>::from_shape_vec((n, policy_len), mask_data).context("action_masks")?;

deep_quoridor/rust/src/rotation.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99
1010
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
1111

12-
use crate::actions::{
13-
action_index_to_action, action_to_index, policy_size, ACTION_MOVE,
14-
};
12+
use crate::actions::{action_index_to_action, action_to_index, policy_size, ACTION_MOVE};
1513

1614
/// Rotate a 2D grid 180° — equivalent to `np.rot90(grid, k=2)`.
1715
///

0 commit comments

Comments
 (0)