Skip to content

Commit f9e6276

Browse files
fix(cfr): include remaining stack in board enumeration rewards and guard sequential pruning (#275)
1 parent 279bef1 commit f9e6276

1 file changed

Lines changed: 40 additions & 9 deletions

File tree

src/arena/cfr/agent.rs

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,17 @@ where
797797

798798
// Decide whether to prune this iteration.
799799
// On reprobe iterations (every REPROBE_INTERVAL-th), explore all actions.
800-
let prune_this_iter = can_prune || updates_since_warmup >= PRUNE_WARMUP;
800+
//
801+
// The len() > 2 guard is required here even though `can_prune`
802+
// already checks it. The second disjunct (`updates_since_warmup
803+
// >= PRUNE_WARMUP`) handles nodes that cross the warmup
804+
// threshold mid-call, but it does not carry the action-count
805+
// check. Without the outer guard, 2-action nodes could have one
806+
// action pruned on 75% of iterations, collapsing to a fixed
807+
// policy with no exploration. This matches the parallel path
808+
// (L767) which gates all pruning on `can_prune`.
809+
let prune_this_iter = indexed_actions.len() > 2
810+
&& (can_prune || updates_since_warmup >= PRUNE_WARMUP);
801811
let is_reprobe = iter_idx % REPROBE_INTERVAL == 0;
802812
let skip_pruned = prune_this_iter && !is_reprobe;
803813

@@ -829,8 +839,13 @@ where
829839
updates_since_warmup += 1;
830840

831841
// After a reprobe iteration, refresh the active action set
832-
// from the updated regret matcher.
833-
if is_reprobe && (can_prune || updates_since_warmup >= PRUNE_WARMUP) {
842+
// from the updated regret matcher. The len() > 2 guard keeps
843+
// this consistent with the pruning decision above — there is
844+
// no point refreshing an active set we will never use.
845+
if is_reprobe
846+
&& indexed_actions.len() > 2
847+
&& (can_prune || updates_since_warmup >= PRUNE_WARMUP)
848+
{
834849
let (new_active, _) = self.cfr_state.get_pruning_info(target_node_idx);
835850
active_actions = new_active;
836851
}
@@ -1047,11 +1062,13 @@ fn fast_forward_enumerate_showdowns(gs: &GameState, player_idx: usize, cards_nee
10471062
}
10481063

10491064
// Single contender: they win everything regardless of board.
1065+
// Include remaining_stack because fast_forward_advance_betting has already
1066+
// moved some chips from stacks into the pot.
10501067
if contender_count == 1 {
10511068
let winner = contenders.ones().next().unwrap();
10521069
let pot = gs.total_pot;
10531070
let winnings = if winner == player_idx { pot } else { 0.0 };
1054-
return winnings - gs.starting_stacks[player_idx];
1071+
return gs.stacks[player_idx] + winnings - gs.starting_stacks[player_idx];
10551072
}
10561073

10571074
let pot = gs.total_pot;
@@ -1069,14 +1086,21 @@ fn fast_forward_enumerate_showdowns(gs: &GameState, player_idx: usize, cards_nee
10691086
}
10701087

10711088
let starting_stack = gs.starting_stacks[player_idx];
1089+
// After fast_forward_advance_betting, chips have moved from stacks into
1090+
// the pot. `evaluate_with_extra_cards` returns only the player's share of
1091+
// the pot (or 0), so the net reward is:
1092+
// remaining_stack + pot_share - starting_stack
1093+
// The remaining_stack term accounts for the chips the player kept — without
1094+
// it the reward would be off by exactly the unbet portion of their stack.
1095+
let remaining_stack = gs.stacks[player_idx];
10721096
let mut total_reward = 0.0f64;
10731097
let mut count = 0u32;
10741098

10751099
if cards_needed == 1 {
10761100
// Enumerate single card (river).
10771101
for &card in &remaining {
10781102
let reward = evaluate_with_extra_cards(gs, &contenders, pot, player_idx, &[card]);
1079-
total_reward += f64::from(reward - starting_stack);
1103+
total_reward += f64::from(remaining_stack + reward - starting_stack);
10801104
count += 1;
10811105
}
10821106
} else {
@@ -1092,7 +1116,7 @@ fn fast_forward_enumerate_showdowns(gs: &GameState, player_idx: usize, cards_nee
10921116
player_idx,
10931117
&[remaining[i], remaining[j]],
10941118
);
1095-
total_reward += f64::from(reward - starting_stack);
1119+
total_reward += f64::from(remaining_stack + reward - starting_stack);
10961120
count += 1;
10971121
}
10981122
}
@@ -1133,14 +1157,16 @@ fn fast_forward_sample_flop_enumerate_runout_n<R: Rng>(
11331157
let contenders = gs.player_active | gs.player_all_in;
11341158
let contender_count = contenders.count();
11351159

1160+
// No contender: defer to gs.player_reward(). Single contender: include
1161+
// remaining_stack, as fast_forward_advance_betting has already moved chips into the pot.
11361162
if contender_count <= 1 {
11371163
if contender_count == 0 {
11381164
return gs.player_reward(player_idx);
11391165
}
11401166
let winner = contenders.ones().next().unwrap();
11411167
let pot = gs.total_pot;
11421168
let winnings = if winner == player_idx { pot } else { 0.0 };
1143-
return winnings - gs.starting_stacks[player_idx];
1169+
return gs.stacks[player_idx] + winnings - gs.starting_stacks[player_idx];
11441170
}
11451171

11461172
let pot = gs.total_pot;
@@ -1150,6 +1176,9 @@ fn fast_forward_sample_flop_enumerate_runout_n<R: Rng>(
11501176

11511177
let mut deck = fast_forward_remaining_deck(gs);
11521178
let starting_stack = gs.starting_stacks[player_idx];
1179+
// See comment in fast_forward_enumerate_showdowns — remaining_stack
1180+
// accounts for unbet chips after fast_forward_advance_betting.
1181+
let remaining_stack = gs.stacks[player_idx];
11531182
let mut total_reward = 0.0f64;
11541183
let mut total_count = 0u64;
11551184

@@ -1184,7 +1213,7 @@ fn fast_forward_sample_flop_enumerate_runout_n<R: Rng>(
11841213
player_idx,
11851214
&[remaining[i], remaining[j]],
11861215
);
1187-
total_reward += f64::from(reward - starting_stack);
1216+
total_reward += f64::from(remaining_stack + reward - starting_stack);
11881217
total_count += 1;
11891218
}
11901219
}
@@ -1258,14 +1287,16 @@ fn evaluate_with_extra_cards(
12581287
}
12591288

12601289
/// Evaluate showdown with the current board (no extra cards).
1290+
/// Returns `remaining_stack + pot_share - starting_stack` to account for
1291+
/// chips already moved from stacks into the pot by `fast_forward_advance_betting`.
12611292
fn evaluate_showdown_reward(
12621293
gs: &GameState,
12631294
contenders: &PlayerBitSet,
12641295
pot: f32,
12651296
player_idx: usize,
12661297
) -> f32 {
12671298
let reward = evaluate_with_extra_cards(gs, contenders, pot, player_idx, &[]);
1268-
reward - gs.starting_stacks[player_idx]
1299+
gs.stacks[player_idx] + reward - gs.starting_stacks[player_idx]
12691300
}
12701301

12711302
impl<T, R> Agent for CFRAgent<T, R>

0 commit comments

Comments
 (0)