Skip to content

Commit 279bef1

Browse files
perf(cfr): single shared tree, compact action space, regret-based pruning, and board enumeration (#274)
Implement four complementary optimizations from the CFR literature that together reduce memory usage by ~80% and improve convergence quality.

Previously each player maintained an independent CFR tree (`Vec<CFRState>`). Since the traversal records all information (no information hiding), every player's tree was structurally identical — duplicating all nodes and regret data. This commit replaces the per-player vec with a single `CFRState` shared by all players. For a 2-player game this halves the `NodeArena` memory. The historian now creates nodes once and moves all traversal states together, and sub-agent construction no longer clones a vec of Arc-backed states.

The action index mapper allocated 52 slots per regret matcher (fold, call, 49 raise buckets, all-in) but typical action generators produce only 4-8 distinct bet sizes. The 49 raise slots provided ~2% log-space resolution — far finer than needed. Reducing to 12 raise slots (indices 2-13) plus fold, call, all-in, and one reserved slot gives 16 total indices with ~38% resolution, still sufficient to distinguish standard pot-fraction bets (0.33x, 0.67x, 1.0x, 1.5x). Each regret matcher stores a weight vector sized to `NUM_ACTION_INDICES`, so this change reduces per-node regret storage by ~69%.

After a warmup period (3 updates), actions whose cumulative regret has been driven to zero by PCFR+ clamping are skipped during reward computation. Computing rewards requires expensive recursive sub-simulations, so skipping dead actions saves significant wall-clock time. Every 4th iteration all actions are reprobed to detect actions that have become relevant again. The pruning integrates into both the sequential and parallel (rayon) reward computation paths.

When computing fast-forward rewards at leaf depth, the previous approach sampled a single random board completion. This introduces variance into the reward signal, which slows CFR convergence.
The new approach depends on how many community cards remain:

- 0 cards (showdown): 1 evaluation — deterministic.
- 1 card (river only): ~46 evaluations — full enumeration.
- 2 cards (turn + river): ~C(46,2) ≈ 1035 evaluations — full enumeration.
- 3 cards (flop): sample k=3 random flops, then enumerate all turn+river combinations for each (~946 evals per flop, ~2838 total). This hybrid approach gives 3.8x variance reduction vs single-sample at only 2x the cost of k=1. Testing k=1,2,3,5,8,13 showed diminishing returns beyond k=3 (each additional sample adds <0.5x reduction while doubling cost).

The 0-2 card cases produce zero-variance reward signals. The 3-card hybrid approach substantially reduces variance from the flop while keeping cost tractable for the CFR inner loop. Related to the AIVAT variance reduction technique (Burch et al., AAAI 2018).

API changes:

- `CFRAgentBuilder::cfr_states(Vec<CFRState>)` → `cfr_state(CFRState)`
- `HoldemSimulationBuilder::cfr_context_arc()` removed; `cfr_context()` now takes a single `CFRState` instead of `Vec<CFRState>`
- `CFRHistorian::new()` takes `&CFRState` instead of `&Arc<[CFRState]>`
- `ConfigAgentBuilder::cfr_context()` takes `CFRState` instead of vec

Tests:

- `test_rbp_preserves_fold_decision`: verifies pruning still produces the correct fold for K-high facing an all-in on a paired board.
- `test_rbp_reduces_active_actions`: verifies the pruning bitset becomes sparse after warmup, confirming RBP activates and prunes dead actions.
- `test_flop_sample_variance_vs_single_sample`: confirms 3.8x variance reduction with k=3 flop samples vs single random runout.
- `test_flop_sample_dominated_hand`: AKs vs 72o produces correct positive EV (~160) matching theoretical 66% equity.
- `test_flop_sample_count_comparison`: parameterized comparison across k=1,2,3,5,8,13 showing cost/variance tradeoff (k=3 is the knee).
- All existing CFR tests updated for the single-state API.
1 parent 6d74867 commit 279bef1

15 files changed

Lines changed: 1195 additions & 386 deletions

File tree

benches/cfr.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ fn run_cfr_configurable_arena(num_hands: usize) -> GameState {
6565
.build()
6666
.unwrap();
6767

68-
let cfr_states: Vec<CFRState> = (0..2).map(|_| CFRState::new(game_state.clone())).collect();
68+
let cfr_state = CFRState::new(game_state.clone());
6969
let traversal_set = TraversalSet::new(2);
7070
let builder = ConfigAgentBuilder::from_json(&json).expect("Failed to parse CFR config");
7171

@@ -75,15 +75,15 @@ fn run_cfr_configurable_arena(num_hands: usize) -> GameState {
7575
.clone()
7676
.player_idx(idx)
7777
.game_state(game_state.clone())
78-
.cfr_context(cfr_states.clone(), traversal_set.clone())
78+
.cfr_context(cfr_state.clone(), traversal_set.clone())
7979
.build()
8080
})
8181
.collect();
8282

8383
let mut sim = HoldemSimulationBuilder::default()
8484
.game_state(game_state)
8585
.agents(agents)
86-
.cfr_context(cfr_states, traversal_set, true)
86+
.cfr_context(cfr_state, traversal_set, true)
8787
.build()
8888
.unwrap();
8989

@@ -100,7 +100,7 @@ fn run_cfr_configurable_arena_default() -> GameState {
100100
.build()
101101
.unwrap();
102102

103-
let cfr_states: Vec<CFRState> = (0..2).map(|_| CFRState::new(game_state.clone())).collect();
103+
let cfr_state = CFRState::new(game_state.clone());
104104
let traversal_set = TraversalSet::new(2);
105105
let builder =
106106
ConfigAgentBuilder::from_json(CFR_CONFIGURABLE_JSON).expect("Failed to parse CFR config");
@@ -111,15 +111,15 @@ fn run_cfr_configurable_arena_default() -> GameState {
111111
.clone()
112112
.player_idx(idx)
113113
.game_state(game_state.clone())
114-
.cfr_context(cfr_states.clone(), traversal_set.clone())
114+
.cfr_context(cfr_state.clone(), traversal_set.clone())
115115
.build()
116116
})
117117
.collect();
118118

119119
let mut sim = HoldemSimulationBuilder::default()
120120
.game_state(game_state)
121121
.agents(agents)
122-
.cfr_context(cfr_states, traversal_set, true)
122+
.cfr_context(cfr_state, traversal_set, true)
123123
.build()
124124
.unwrap();
125125

fuzz/fuzz_targets/cfr_mixed_agents.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,8 @@ fuzz_target!(|input: MixedAgentInput| {
6262
.build()
6363
.unwrap();
6464

65-
// Create shared CFR context
66-
let cfr_states: Vec<CFRState> = (0..game_state.num_players)
67-
.map(|_| CFRState::new(game_state.clone()))
68-
.collect();
65+
// Create shared CFR context (single tree for all players)
66+
let cfr_state = CFRState::new(game_state.clone());
6967
let traversal_set = TraversalSet::new(game_state.num_players);
7068

7169
let agents: Vec<Box<dyn Agent>> = vec![
@@ -74,7 +72,7 @@ fuzz_target!(|input: MixedAgentInput| {
7472
input.cfr_indices[0],
7573
input.cfr_variants[0],
7674
&input.player0_actions,
77-
&cfr_states,
75+
&cfr_state,
7876
&traversal_set,
7977
&depth_hands,
8078
),
@@ -83,7 +81,7 @@ fuzz_target!(|input: MixedAgentInput| {
8381
input.cfr_indices[1],
8482
input.cfr_variants[1],
8583
&input.player1_actions,
86-
&cfr_states,
84+
&cfr_state,
8785
&traversal_set,
8886
&depth_hands,
8987
),
@@ -93,7 +91,7 @@ fuzz_target!(|input: MixedAgentInput| {
9391
let mut sim = HoldemSimulationBuilder::default()
9492
.game_state(game_state)
9593
.agents(agents)
96-
.cfr_context(cfr_states, traversal_set, true)
94+
.cfr_context(cfr_state, traversal_set, true)
9795
.build()
9896
.unwrap();
9997

@@ -111,7 +109,7 @@ fn create_agent(
111109
is_cfr: bool,
112110
cfr_variant: CfrVariant,
113111
actions: &[AgentAction],
114-
cfr_states: &[CFRState],
112+
cfr_state: &CFRState,
115113
traversal_set: &TraversalSet,
116114
depth_hands: &[usize],
117115
) -> Box<dyn Agent> {
@@ -122,7 +120,7 @@ fn create_agent(
122120
CFRAgentBuilder::<SimpleActionGenerator>::new()
123121
.name(format!("CFRAgent-{player_idx}"))
124122
.player_idx(player_idx)
125-
.cfr_states(cfr_states.to_vec())
123+
.cfr_state(cfr_state.clone())
126124
.traversal_set(traversal_set.clone())
127125
.depth_config(depth_config)
128126
.action_gen_config(())
@@ -147,7 +145,7 @@ fn create_agent(
147145
CFRAgentBuilder::<ConfigurableActionGenerator>::new()
148146
.name(format!("CFRConfigurableAgent-{player_idx}"))
149147
.player_idx(player_idx)
150-
.cfr_states(cfr_states.to_vec())
148+
.cfr_state(cfr_state.clone())
151149
.traversal_set(traversal_set.clone())
152150
.depth_config(depth_config)
153151
.action_gen_config(config)

fuzz/fuzz_targets/config_agent.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,9 @@ fuzz_target!(|input: ConfigAgentInput| {
175175
// Check if any agent is CFR-based; if so create shared context
176176
let has_cfr = input.player_configs.iter().any(|c| c.is_cfr());
177177
let cfr_context = if has_cfr {
178-
let cfr_states: Vec<CFRState> = (0..game_state.num_players)
179-
.map(|_| CFRState::new(game_state.clone()))
180-
.collect();
178+
let cfr_state = CFRState::new(game_state.clone());
181179
let traversal_set = TraversalSet::new(game_state.num_players);
182-
Some((cfr_states, traversal_set))
180+
Some((cfr_state, traversal_set))
183181
} else {
184182
None
185183
};
@@ -194,8 +192,8 @@ fuzz_target!(|input: ConfigAgentInput| {
194192
let mut builder = ConfigAgentBuilder::new(config.clone())?
195193
.player_idx(idx)
196194
.game_state(game_state.clone());
197-
if let Some((ref cfr_states, ref ts)) = cfr_context {
198-
builder = builder.cfr_context(cfr_states.clone(), ts.clone());
195+
if let Some((ref cfr_state, ref ts)) = cfr_context {
196+
builder = builder.cfr_context(cfr_state.clone(), ts.clone());
199197
}
200198
Ok(builder.build())
201199
})
@@ -219,8 +217,8 @@ fuzz_target!(|input: ConfigAgentInput| {
219217
.game_state(game_state)
220218
.agents(agents)
221219
.historians(historians);
222-
if let Some((cfr_states, traversal_set)) = cfr_context {
223-
sim_builder = sim_builder.cfr_context(cfr_states, traversal_set, true);
220+
if let Some((cfr_state, traversal_set)) = cfr_context {
221+
sim_builder = sim_builder.cfr_context(cfr_state, traversal_set, true);
224222
}
225223
let mut sim: HoldemSimulation = sim_builder.build().unwrap();
226224
sim.run(&mut rng);

src/arena/agent/config.rs

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ pub struct ConfigAgentBuilder {
442442
config: AgentConfig,
443443
player_idx: Option<usize>,
444444
game_state: Option<GameState>,
445-
cfr_states: Option<Vec<CFRState>>,
445+
cfr_state: Option<CFRState>,
446446
traversal_set: Option<TraversalSet>,
447447
thread_pool: Option<Arc<rayon::ThreadPool>>,
448448
rng_seed: Option<u64>,
@@ -456,7 +456,7 @@ impl ConfigAgentBuilder {
456456
config,
457457
player_idx: None,
458458
game_state: None,
459-
cfr_states: None,
459+
cfr_state: None,
460460
traversal_set: None,
461461
thread_pool: None,
462462
rng_seed: None,
@@ -471,18 +471,15 @@ impl ConfigAgentBuilder {
471471

472472
/// Set the game state for the agent.
473473
///
474-
/// For CFR agents, this eagerly initializes CFR states and a
474+
/// For CFR agents, this eagerly initializes a shared CFR state and a
475475
/// `TraversalSet` (unless explicit context was already provided via
476476
/// `cfr_context()`). This means cloned builders automatically share
477477
/// the same CFR state (CFRState is cheap to clone via Arc).
478478
pub fn game_state(mut self, game_state: GameState) -> Self {
479479
// Eagerly create CFR context so that cloned builders share the
480-
// same Arc-backed states. Explicit cfr_context() takes priority.
481-
if self.config.is_cfr() && self.cfr_states.is_none() {
482-
let cfr_states: Vec<CFRState> = (0..game_state.num_players)
483-
.map(|_| CFRState::new(game_state.clone()))
484-
.collect();
485-
self.cfr_states = Some(cfr_states);
480+
// same Arc-backed state. Explicit cfr_context() takes priority.
481+
if self.config.is_cfr() && self.cfr_state.is_none() {
482+
self.cfr_state = Some(CFRState::new(game_state.clone()));
486483
self.traversal_set = Some(TraversalSet::new(game_state.num_players));
487484
}
488485
self.game_state = Some(game_state);
@@ -494,8 +491,8 @@ impl ConfigAgentBuilder {
494491
/// When set, all CFR agents built will share the same CFR state
495492
/// and `TraversalSet`, enabling shared learning across agents.
496493
/// When not set, each CFR agent creates its own.
497-
pub fn cfr_context(mut self, cfr_states: Vec<CFRState>, traversal_set: TraversalSet) -> Self {
498-
self.cfr_states = Some(cfr_states);
494+
pub fn cfr_context(mut self, cfr_state: CFRState, traversal_set: TraversalSet) -> Self {
495+
self.cfr_state = Some(cfr_state);
499496
self.traversal_set = Some(traversal_set);
500497
self
501498
}
@@ -606,24 +603,24 @@ impl ConfigAgentBuilder {
606603
}
607604
}
608605
AgentConfig::CfrBasic { name, depth_hands } => {
609-
let (cfr_states, traversal_set) = self.resolve_cfr_context();
606+
let (cfr_state, traversal_set) = self.resolve_cfr_context();
610607
let depth_config = CfrDepthConfig::new(depth_hands.clone());
611608
let builder = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
612609
.name(resolve_agent_name(name, "CFRAgent", player_idx))
613610
.player_idx(player_idx)
614-
.cfr_states(cfr_states)
611+
.cfr_state(cfr_state)
615612
.traversal_set(traversal_set)
616613
.depth_config(depth_config)
617614
.action_gen_config(());
618615
Box::new(self.apply_cfr_options(builder).build())
619616
}
620617
AgentConfig::CfrSimple { name, depth_hands } => {
621-
let (cfr_states, traversal_set) = self.resolve_cfr_context();
618+
let (cfr_state, traversal_set) = self.resolve_cfr_context();
622619
let depth_config = CfrDepthConfig::new(depth_hands.clone());
623620
let builder = CFRAgentBuilder::<SimpleActionGenerator>::new()
624621
.name(resolve_agent_name(name, "CFRSimpleAgent", player_idx))
625622
.player_idx(player_idx)
626-
.cfr_states(cfr_states)
623+
.cfr_state(cfr_state)
627624
.traversal_set(traversal_set)
628625
.depth_config(depth_config)
629626
.action_gen_config(());
@@ -634,12 +631,12 @@ impl ConfigAgentBuilder {
634631
depth_hands,
635632
action_config,
636633
} => {
637-
let (cfr_states, traversal_set) = self.resolve_cfr_context();
634+
let (cfr_state, traversal_set) = self.resolve_cfr_context();
638635
let depth_config = CfrDepthConfig::new(depth_hands.clone());
639636
let builder = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
640637
.name(resolve_agent_name(name, "CFRConfigurableAgent", player_idx))
641638
.player_idx(player_idx)
642-
.cfr_states(cfr_states)
639+
.cfr_state(cfr_state)
643640
.traversal_set(traversal_set)
644641
.depth_config(depth_config)
645642
.action_gen_config(action_config.as_ref().clone());
@@ -654,7 +651,7 @@ impl ConfigAgentBuilder {
654651
let resolved_preflop_config = preflop_config
655652
.resolve()
656653
.expect("Invalid preflop config - should have been validated");
657-
let (cfr_states, traversal_set) = self.resolve_cfr_context();
654+
let (cfr_state, traversal_set) = self.resolve_cfr_context();
658655
let depth_config = CfrDepthConfig::new(depth_hands.clone());
659656
let action_config = PreflopChartActionConfig {
660657
preflop_config: resolved_preflop_config,
@@ -666,7 +663,7 @@ impl ConfigAgentBuilder {
666663
let builder = CFRAgentBuilder::<PreflopChartActionGenerator>::new()
667664
.name(resolve_agent_name(name, "CFRPreflopChartAgent", player_idx))
668665
.player_idx(player_idx)
669-
.cfr_states(cfr_states)
666+
.cfr_state(cfr_state)
670667
.traversal_set(traversal_set)
671668
.depth_config(depth_config)
672669
.action_gen_config(action_config);
@@ -698,16 +695,16 @@ impl ConfigAgentBuilder {
698695
/// # Panics
699696
///
700697
/// Panics if neither `cfr_context()` nor `game_state()` was called.
701-
fn resolve_cfr_context(&self) -> (Vec<CFRState>, TraversalSet) {
702-
let cfr_states = self
703-
.cfr_states
698+
fn resolve_cfr_context(&self) -> (CFRState, TraversalSet) {
699+
let cfr_state = self
700+
.cfr_state
704701
.clone()
705702
.expect("cfr_context() or game_state() is required for CFR agents");
706703
let traversal_set = self
707704
.traversal_set
708705
.clone()
709706
.expect("cfr_context() or game_state() is required for CFR agents");
710-
(cfr_states, traversal_set)
707+
(cfr_state, traversal_set)
711708
}
712709
}
713710

src/arena/cfr/action_bit_set.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
1-
/// A bit set for tracking action indices (0-51).
1+
/// A bit set for tracking action indices (0-15) and card indices (0-51).
22
///
3-
/// This is optimized for the CFR action space which has 52 possible action indices:
4-
/// - Index 0: Fold
5-
/// - Index 1: Call/Check
6-
/// - Indices 2-50: Raises (logarithmic distribution)
7-
/// - Index 51: All-in
8-
///
9-
/// Using a `u64` allows O(1) insert and contains operations with no heap allocation,
10-
/// which is much faster than `HashSet<usize>` for this use case.
3+
/// Used for the CFR action space (16 action indices) and for card tracking.
4+
/// Using a `u64` allows O(1) insert and contains operations with no heap
5+
/// allocation, supporting up to 64 indices.
116
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)]
127
pub struct ActionBitSet {
138
bits: u64,
@@ -30,7 +25,7 @@ impl ActionBitSet {
3025
/// Panics in debug mode if `idx >= 64`.
3126
#[inline]
3227
pub fn insert(&mut self, idx: usize) -> bool {
33-
debug_assert!(idx < 52, "Action index must be < 52, got {}", idx);
28+
debug_assert!(idx < 64, "Bit index must be < 64, got {}", idx);
3429
let mask = 1u64 << idx;
3530
let was_present = (self.bits & mask) != 0;
3631
self.bits |= mask;
@@ -40,7 +35,7 @@ impl ActionBitSet {
4035
/// Returns `true` if the set contains the given action index.
4136
#[inline]
4237
pub fn contains(&self, idx: usize) -> bool {
43-
debug_assert!(idx < 52, "Action index must be < 52, got {}", idx);
38+
debug_assert!(idx < 64, "Bit index must be < 64, got {}", idx);
4439
(self.bits & (1u64 << idx)) != 0
4540
}
4641

src/arena/cfr/action_generator/preflop_chart.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -747,18 +747,16 @@ mod tests {
747747
let config = create_simple_config();
748748
let depth_config = CfrDepthConfig::new(vec![1]);
749749

750-
// Create CFR agents with PreflopChartActionGenerator sharing the same CFR states
751-
let cfr_states: Vec<CFRState> = (0..game_state.num_players)
752-
.map(|_| CFRState::new(game_state.clone()))
753-
.collect();
750+
// Create CFR agents with PreflopChartActionGenerator sharing a single CFR state
751+
let cfr_state = CFRState::new(game_state.clone());
754752
let traversal_set = TraversalSet::new(game_state.num_players);
755753
let agents: Vec<Box<dyn Agent>> = (0..2)
756754
.map(|idx| {
757755
Box::new(
758756
CFRAgentBuilder::<PreflopChartActionGenerator>::new()
759757
.name(format!("PreflopChartAgent-{idx}"))
760758
.player_idx(idx)
761-
.cfr_states(cfr_states.clone())
759+
.cfr_state(cfr_state.clone())
762760
.traversal_set(traversal_set.clone())
763761
.depth_config(depth_config.clone())
764762
.action_gen_config(config.clone())
@@ -772,7 +770,7 @@ mod tests {
772770
let mut sim = HoldemSimulationBuilder::default()
773771
.game_state(game_state)
774772
.agents(agents)
775-
.cfr_context(cfr_states, traversal_set, true)
773+
.cfr_context(cfr_state, traversal_set, true)
776774
.build()
777775
.unwrap();
778776

@@ -800,17 +798,15 @@ mod tests {
800798
.build()
801799
.unwrap();
802800

803-
let cfr_states: Vec<CFRState> = (0..game_state.num_players)
804-
.map(|_| CFRState::new(game_state.clone()))
805-
.collect();
801+
let cfr_state = CFRState::new(game_state.clone());
806802
let traversal_set = TraversalSet::new(game_state.num_players);
807803
let agents: Vec<Box<dyn Agent>> = (0..2)
808804
.map(|idx| {
809805
Box::new(
810806
CFRAgentBuilder::<PreflopChartActionGenerator>::new()
811807
.name(format!("PreflopChartAgent-game{game_idx}-p{idx}"))
812808
.player_idx(idx)
813-
.cfr_states(cfr_states.clone())
809+
.cfr_state(cfr_state.clone())
814810
.traversal_set(traversal_set.clone())
815811
.depth_config(depth_config.clone())
816812
.action_gen_config(config.clone())
@@ -824,7 +820,7 @@ mod tests {
824820
let mut sim = HoldemSimulationBuilder::default()
825821
.game_state(game_state)
826822
.agents(agents)
827-
.cfr_context(cfr_states, traversal_set, true)
823+
.cfr_context(cfr_state, traversal_set, true)
828824
.build()
829825
.unwrap();
830826

0 commit comments

Comments
 (0)