diff --git a/rlcard/agents/bauernskat/dmc_agent/agent.py b/rlcard/agents/bauernskat/dmc_agent/agent.py new file mode 100644 index 000000000..099f5adf4 --- /dev/null +++ b/rlcard/agents/bauernskat/dmc_agent/agent.py @@ -0,0 +1,224 @@ +''' + File name: rlcard/games/bauernskat/dmc_agent/agent.py + Author: Oliver Czerwinski + Date created: 08/13/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +import random +from collections import Counter +from typing import Dict, Tuple, List, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts + +from rlcard.envs.env import Env +from rlcard.games.bauernskat.action_event import ActionEvent, DeclareTrumpAction +from rlcard.games.bauernskat.card import BauernskatCard +from rlcard.agents.bauernskat import rule_agents as bauernskat_rule_agents + +from rlcard.agents.bauernskat.dmc_agent.model import BauernskatNet +from rlcard.agents.bauernskat.dmc_agent.config import BauernskatNetConfig + +class Estimator: + """ + Q-value estimator using a neural network. + """ + + def __init__(self, net_config: BauernskatNetConfig, learning_rate: float, lr_gamma: float, device: torch.device, weight_decay: float = 1e-6, cosine_T0: int = 51_200_000, cosine_T_mult: int = 2, cosine_eta_min: float = 3e-6): + """ + Initializes Estimator. + """ + + self.device = device + self.qnet = BauernskatNet(net_config).to(device) + + self.optimizer = torch.optim.AdamW(self.qnet.parameters(), lr=learning_rate, weight_decay=weight_decay) + + self.scheduler = CosineAnnealingWarmRestarts( + self.optimizer, + T_0=cosine_T0, + T_mult=cosine_T_mult, + eta_min=cosine_eta_min + ) + + def train_step(self, states: dict, actions: dict, targets: torch.Tensor, clip_norm: float) -> float: + """ + Performs a training step on a batch of state-action pairs and targets. + """ + + self.qnet.train() + predicted_q_values = self.qnet(states, actions) + + loss = F.mse_loss(predicted_q_values, targets) + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.qnet.parameters(), clip_norm) + self.optimizer.step() + self.scheduler.step() + return loss.item() + + def predict_nograd(self, states: dict, actions: dict) -> torch.Tensor: + """ + Predicts Q-values for a batch of state-action pairs without gradient updates. + """ + + self.qnet.eval() + with torch.no_grad(): + return self.qnet(states, actions) + + +class AgentDMC_Actor: + """ + An agent that uses an action-in Q-value network for decision making. + """ + + def __init__(self, net: BauernskatNet, device: str = 'cpu', use_teacher: bool = False): + """ + Initializes AgentDMC_Actor. + """ + + self.use_raw = True + self.net = net + self.device = device + + # Optional teacher + self.teacher = bauernskat_rule_agents.BauernskatLookaheadRuleAgent() if use_teacher else None + + @staticmethod + def _get_action_obs(action_id: int) -> Dict[str, List[int]]: + """ + Creates an action_obs dictionary for a given action_id. + """ + + if action_id >= 5: + return {'action_card_ids': [action_id - 5], 'trump_action_id': [-1]} + else: + return {'action_card_ids': [-1], 'trump_action_id': [action_id]} + + def _get_best_action(self, state: dict, env: Env) -> int: + """ + Encodes all legal actions at once and selects the one with the highest Q-value. 
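+
+        Illustrative sketch (hand-picked ids; the id >= 5 split follows
+        _get_action_obs): with legal_actions = [3, 7], the state is encoded
+        once and repeated to a batch of two, then scored against
+        {'action_card_ids': [[-1], [2]], 'trump_action_id': [[3], [-1]]};
+        the legal action with the larger Q-value is returned.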
+ """ + legal_actions = list(state['legal_actions'].keys()) + if not legal_actions: + return -1 + + with torch.no_grad(): + state_batch = {k: torch.from_numpy(np.array(v)).unsqueeze(0).to(self.device) + for k, v in state['obs'].items()} + + state_encoding = self.net.encode_state(state_batch) + + action_obs_list = [self._get_action_obs(a) for a in legal_actions] + card_ids = [ao['action_card_ids'] for ao in action_obs_list] + trump_ids = [ao['trump_action_id'] for ao in action_obs_list] + action_batch = { + 'action_card_ids': torch.LongTensor(card_ids).to(self.device), + 'trump_action_id': torch.LongTensor(trump_ids).to(self.device) + } + + repeated_state_encoding = state_encoding.repeat(len(legal_actions), 1) + + q_values = self.net.predict_q(repeated_state_encoding, action_batch).cpu().numpy() + + return legal_actions[np.argmax(q_values)] + + def _get_rule_based_trump_action(self, state: dict) -> Optional[int]: + """ + Heuristic for trump selection: + - If 2 or more Jacks: Declare Grand ('G') + - Else: Suit with most cards. + - Otherwise: Suit with highest rank card. + """ + + legal_action_ids = list(state['legal_actions'].keys()) + raw_info = state.get('raw_state_info') + if not raw_info: return None + + my_cards = raw_info.get('my_cards', []) + if not my_cards: return random.choice(legal_action_ids) + + # Jacks + num_jacks = sum(1 for card in my_cards if card.rank == 'J') + if num_jacks >= 2: + for action_id in legal_action_ids: + action = ActionEvent.from_action_id(action_id) + if isinstance(action, DeclareTrumpAction) and action.trump_suit == 'G': + return action_id + + # Suit counts + suit_counts = Counter(card.suit for card in my_cards) + if not suit_counts: + return random.choice(legal_action_ids) + + max_count = max(suit_counts.values()) + best_suits = [suit for suit, count in suit_counts.items() if count == max_count] + + # Highest rank suit card + if len(best_suits) == 1: + best_suit_choice = best_suits[0] + else: + best_rank_val = -1 + + rank_order = BauernskatCard.ranks + best_suit_choice = best_suits[0] + for suit in best_suits: + cards_of_suit = [card for card in my_cards if card.suit == suit] + if cards_of_suit: + max_rank_in_suit = max(rank_order.index(c.rank) for c in cards_of_suit) + if max_rank_in_suit > best_rank_val: + best_rank_val = max_rank_in_suit + best_suit_choice = suit + + for action_id in legal_action_ids: + action = ActionEvent.from_action_id(action_id) + if isinstance(action, DeclareTrumpAction) and action.trump_suit == best_suit_choice: + return action_id + + return random.choice(legal_action_ids) + + def step(self, state: dict, env: Env, epsilon: float = 0.0, trump_rule_prob: float = 0.0, teacher_epsilon: float = 0.0) -> Tuple[int, Dict]: + """ + Chooses an action based on epsilon-greedy strategy with optional teacher forcing and rule-based trump selection. 
+ """ + + r = random.random() + + # Teacher Forcing + if self.teacher is not None and r < teacher_epsilon: + action = self.teacher.step(state) + action_obs = self._get_action_obs(action) + return action, action_obs + + # Rule-Based Trump Selection + if trump_rule_prob > 0.0: + raw_info = state.get('raw_state_info') + if raw_info and raw_info.get('round_phase') == 'declare_trump': + if random.random() < trump_rule_prob: + action = self._get_rule_based_trump_action(state) + if action is not None: + action_obs = self._get_action_obs(action) + return action, action_obs + + # Random Exploration + if r < epsilon: + action = random.choice(list(state['legal_actions'].keys())) + else: + # Greedy + action = self._get_best_action(state, env) + + action_obs = self._get_action_obs(action) + return action, action_obs + + def eval_step(self, state: dict, env: Env) -> Tuple[int, Dict]: + """ + Chooses the best action without exploration. + """ + with torch.no_grad(): + action, _ = self.step(state, env, epsilon=0.0, trump_rule_prob=0.0, teacher_epsilon=0.0) + + return action, {} \ No newline at end of file diff --git a/rlcard/agents/bauernskat/dmc_agent/config.py b/rlcard/agents/bauernskat/dmc_agent/config.py new file mode 100644 index 000000000..c6470b431 --- /dev/null +++ b/rlcard/agents/bauernskat/dmc_agent/config.py @@ -0,0 +1,150 @@ +''' + File name: rlcard/games/bauernskat/dmc_agent/config.py + Author: Oliver Czerwinski + Date created: 08/13/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +from dataclasses import dataclass, field +from typing import Tuple, Literal +import torch + +# Bauernskat specific constants +MAX_PLAYER_CARDS = 16 +MAX_TRICK_SIZE = 2 +MAX_CEMETERY_SIZE = 32 + +# Model architecture +@dataclass +class BauernskatNetConfig: + """ + Configuration for the BauernskatNet model. + """ + + card_embedding_dim: int = 32 + branch_output_dim: int = 96 + + pool_type: Literal['mean', 'sum'] = 'mean' + + mlp_hidden_dims: Tuple[int, ...] = (64, 64) + indicator_mlp_dims: Tuple[int, ...] = (64, 64) + layout_processor_hidden_dim: int = 128 + mask_processor_hidden_dims: Tuple[int, ...] = (64,) + + num_lstm_layers: int = 2 + lstm_hidden_dim: int = 96 + use_bidirectional: bool = True + use_attention: bool = True + attn_heads: int = 4 + lstm_fc_dims: Tuple[int, ...] = (96,) + + context_vector_dim: int = 11 + indicator_vector_dim: int = 8 + action_history_frame_size: int = 49 + + head_hidden_dims: Tuple[int, ...] = (512, 256) + head_dropout: float = 0.0 + +# Training configuration +@dataclass +class TrainerConfig: + """ + Configuration for the DMCTrainer. 
+ """ + + xpid: str = 'dmc_agent_bauernskat_v1' + savedir: str = 'experiments/dmc_agent_result' + load_model: bool = True + save_every_frames: int = 8_192_000 + seed: int = 21000 + + # Logging + log_to_tensorboard: bool = True + log_p0_p1_payoffs: bool = True + log_every_frames: int = 16_384 + log_interval_seconds: float = 5.0 + + # Pipeline & Threading + cuda: str = '0' + training_device: str = "0" + num_actors: int = 10 + num_threads: int = 2 + actor_queue_size_multiplier: int = 256 + actor_game_batch_size: int = 1 + process_join_timeout: float = 5.0 + sample_queue_put_timeout: float = 5.0 + + # Training Hyperparameters + gradient_clip_norm: float = 10.0 + learning_rate: float = 3e-4 + lr_gamma: float = 0.99999424355383592644 + weight_decay: float = 1e-7 + cosine_T0: int = 51_200_000 + cosine_T_mult: int = 2 + cosine_eta_min: float = 3e-5 + batch_size: int = 2_048 + gamma: float = 1.0 + replay_buffer_size: int = 16_384 + min_buffer_size_to_learn: int = 8_192 + + # Prioritized Experience Replay + per_alpha: float = 0.6 + per_beta: float = 0.4 + + # Reward Function + reward_type: Literal['game_score', 'binary', 'hybrid'] = 'hybrid' + + # Parameters for 'hybrid' reward function + max_reward_abs: float = 480.0 + reward_shaping_steepness: float = 0.009 + reward_shaping_threshold: int = 18 + reward_shaping_score_weight: float = 0.5 + reward_shaping_win_bonus: float = 1.0 + + # Agent Hyperparameters + epsilon_start: float = 1.0 + epsilon_end: float = 0.05 + epsilon_decay_frames: int = 20_000_000 + epsilon_decay_type: Literal['linear', 'exponential'] = 'exponential' + epsilon_gamma: float = 0.99999997751381775046 + + # Rule-Based Trump Selection + use_rule_based_trump_decay: bool = False + trump_start: float = 1.0 + trump_end: float = 0.0 + trump_decay_frames: int = 1_024_000_000 + + # Teacher Forcing + use_teacher_forcing: bool = False + teacher_start: float = 1.0 + teacher_end: float = 0.0 + teacher_decay_frames: int = 64_000_000 + + # Environment and Evaluation + env: str = 'bauernskat' + information_level: Literal['normal', 'show_self', 'perfect'] = 'normal' + total_frames: int = 1_024_000_000 + eval_every: int = 8_192_000 + num_eval_games: int = 512 + + # System Configuration + model_config: BauernskatNetConfig = field(default_factory=BauernskatNetConfig) + device: torch.device = field(init=False) + + def __post_init__(self): + """ + Sets the device and validate some hyperparameters. 
+ """ + + if self.training_device != "cpu" and torch.cuda.is_available(): + self.device = torch.device(f"cuda:{self.training_device}") + else: + self.device = torch.device("cpu") + + if self.min_buffer_size_to_learn > self.replay_buffer_size: + raise ValueError("min_buffer_size_to_learn cannot be larger than replay_buffer_size") + if self.batch_size > self.min_buffer_size_to_learn: + raise ValueError("batch_size cannot be larger than min_buffer_size_to_learn") + if self.num_actors <= 0: + raise ValueError("num_actors must be a positive integer") \ No newline at end of file diff --git a/rlcard/agents/bauernskat/dmc_agent/model.py b/rlcard/agents/bauernskat/dmc_agent/model.py new file mode 100644 index 000000000..b89d8d69a --- /dev/null +++ b/rlcard/agents/bauernskat/dmc_agent/model.py @@ -0,0 +1,313 @@ +''' + File name: rlcard/games/bauernskat/dmc_agent/model.py + Author: Oliver Czerwinski + Date created: 08/13/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +import torch +import torch.nn as nn +from typing import Dict, List +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + +from rlcard.agents.bauernskat.dmc_agent.config import BauernskatNetConfig + + +class ResidualBlock(nn.Module): + """ + Basic residual block. + """ + + def __init__(self, dim: int): + """ + Initializes ResidualBlock. + """ + + super().__init__() + self.layers = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * 4), + nn.GELU(), + nn.Linear(dim * 4, dim) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Outputs the result of the residual block. + """ + + return x + self.layers(x) + + +class LayoutProcessor(nn.Module): + """ + Processes a (8, 2) layout tensor. + """ + + def __init__(self, shared_card_embedding: nn.Embedding, output_dim: int, hidden_dim: int): + """ + Initializes LayoutProcessor. + """ + + super().__init__() + self.embedding = shared_card_embedding + embedding_dim = self.embedding.embedding_dim + + input_size = 8 * 2 * embedding_dim + self.mlp = nn.Sequential( + nn.Linear(input_size, hidden_dim), + nn.GELU(), + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, output_dim) + ) + + def forward(self, layout_tensor: torch.Tensor) -> torch.Tensor: + """ + Outputs an embedding for the layout tensor. + """ + + embedded = self.embedding(layout_tensor) + flattened = embedded.view(embedded.shape[0], -1) + return self.mlp(flattened) + + +class CardSetProcessor(nn.Module): + """ + Processes a flexible sized set of cards. + """ + def __init__(self, shared_card_embedding: nn.Embedding, output_dim: int, pool_type: str = 'mean'): + """ + Initializes CardSetProcessor. + """ + + super().__init__() + self.embedding = shared_card_embedding + self.pool_type = pool_type + self.padding_idx = shared_card_embedding.padding_idx + + self.mlp = nn.Sequential( + nn.Linear(self.embedding.embedding_dim, output_dim), + nn.GELU(), + nn.LayerNorm(output_dim) + ) + + def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Outputs an embedding for the set of cards. 
+ """ + + if ids.shape[1] == 0: + return torch.zeros(ids.shape[0], self.mlp[0].out_features, device=ids.device) + + safe_ids = ids.clone() + if self.padding_idx is not None: + safe_ids[ids == -1] = self.padding_idx + + embedded = self.embedding(safe_ids) + + if self.pool_type == 'mean': + num_cards = mask.sum(dim=1, keepdim=True).clamp(min=1) + pooled = embedded.sum(dim=1) / num_cards + elif self.pool_type == 'sum': + pooled = embedded.sum(dim=1) + else: + raise ValueError(f"Unknown pool_type: {self.pool_type}") + + return self.mlp(pooled) + + +class BauernskatNet(nn.Module): + """ + Action-in Q-network for Bauernskat. + """ + + def __init__(self, config: BauernskatNetConfig): + """ + Initializes BauernskatNet. + """ + + super().__init__() + self.config = config + + self.card_embedding = nn.Embedding(33, config.card_embedding_dim, padding_idx=32) + self.trump_action_embedding = nn.Embedding(6, config.card_embedding_dim, padding_idx=5) + + card_set_args = (self.card_embedding, config.branch_output_dim, config.pool_type) + layout_proc_args = (self.card_embedding, config.branch_output_dim, config.layout_processor_hidden_dim) + + self.my_layout_processor = LayoutProcessor(*layout_proc_args) + self.opponent_layout_processor = LayoutProcessor(*layout_proc_args) + self.unaccounted_mask_processor = self._build_mlp( + input_dim=32, + hidden_dims=list(config.mask_processor_hidden_dims), + output_dim=config.branch_output_dim + ) + self.trick_processor = CardSetProcessor(*card_set_args) + self.cemetery_processor = CardSetProcessor(*card_set_args) + + self.indicator_processor = self._build_mlp( + input_dim=config.indicator_vector_dim, + hidden_dims=list(config.indicator_mlp_dims), + output_dim=config.branch_output_dim + ) + + self.context_processor = self._build_mlp( + input_dim=config.context_vector_dim, + hidden_dims=list(config.mlp_hidden_dims), + output_dim=config.branch_output_dim + ) + + lstm_out_dim = config.lstm_hidden_dim * (2 if config.use_bidirectional else 1) + self.lstm = nn.LSTM(config.action_history_frame_size, config.lstm_hidden_dim, config.num_lstm_layers, + bidirectional=config.use_bidirectional, batch_first=True) + self.attn = nn.MultiheadAttention(lstm_out_dim, config.attn_heads, batch_first=True) if config.use_attention else None + + self.history_processor = self._build_mlp( + input_dim=lstm_out_dim, + hidden_dims=list(config.lstm_fc_dims), + output_dim=config.branch_output_dim + ) + + self.action_card_processor = CardSetProcessor(*card_set_args) + self.trump_action_processor = nn.Sequential( + nn.Linear(config.card_embedding_dim, config.branch_output_dim), + nn.GELU(), + nn.LayerNorm(config.branch_output_dim) + ) + + concat_dim = config.branch_output_dim * 9 + head_layers = [] + all_dims = [concat_dim] + list(config.head_hidden_dims) + + for i in range(len(config.head_hidden_dims)): + head_layers.extend([ + nn.Linear(all_dims[i], all_dims[i+1]), + ResidualBlock(all_dims[i+1]), + nn.Dropout(p=config.head_dropout) + ]) + + head_layers.extend([ + nn.LayerNorm(config.head_hidden_dims[-1]), + nn.GELU(), + nn.Linear(config.head_hidden_dims[-1], 1) + ]) + self.head = nn.Sequential(*head_layers) + + @staticmethod + def _build_mlp(input_dim: int, hidden_dims: List[int], output_dim: int) -> nn.Sequential: + """ + Creates an MLP with specific dimensions. 
+ """ + + layers = [] + current_dim = input_dim + for hidden_dim in hidden_dims: + layers.extend([nn.Linear(current_dim, hidden_dim), nn.GELU()]) + current_dim = hidden_dim + layers.append(nn.Linear(current_dim, output_dim)) + return nn.Sequential(*layers) + + def _forward_history(self, x: torch.Tensor) -> torch.Tensor: + """ + Processes the action history using LSTM and attention. + """ + + B, _, _ = x.shape + lengths = torch.sum(x.abs().sum(dim=-1) > 0, dim=-1) + full_batch_summary = torch.zeros(B, self.config.branch_output_dim, device=x.device) + + non_empty_mask = lengths > 0 + if not non_empty_mask.any(): + return full_batch_summary + + non_empty_x = x[non_empty_mask] + non_empty_lengths = lengths[non_empty_mask] + non_empty_indices = non_empty_mask.nonzero(as_tuple=True)[0] + + sorted_lengths, sorted_indices = torch.sort(non_empty_lengths, descending=True) + sorted_x = non_empty_x.index_select(0, sorted_indices) + + packed_input = pack_padded_sequence(sorted_x, sorted_lengths.cpu(), batch_first=True, enforce_sorted=True) + packed_output, _ = self.lstm(packed_input) + lstm_out, _ = pad_packed_sequence(packed_output, batch_first=True, total_length=sorted_x.size(1)) + + _, unsorted_indices = torch.sort(sorted_indices) + lstm_out = lstm_out.index_select(0, unsorted_indices) + + if self.attn: + b_non_empty, s_non_empty, _ = lstm_out.shape + indices = torch.arange(s_non_empty, device=x.device).expand(b_non_empty, -1) + attn_mask = indices >= non_empty_lengths.unsqueeze(1) + last_seq_idxs = (non_empty_lengths - 1).clamp(min=0) + query = lstm_out[torch.arange(b_non_empty), last_seq_idxs, :].unsqueeze(1) + attn_out, _ = self.attn(query=query, key=lstm_out, value=lstm_out, key_padding_mask=attn_mask) + summary = attn_out.squeeze(1) + else: + b_non_empty = lstm_out.shape[0] + last_seq_idxs = (non_empty_lengths - 1).clamp(min=0) + summary = lstm_out[torch.arange(b_non_empty), last_seq_idxs, :] + + processed_summary = self.history_processor(summary) + full_batch_summary.index_add_(0, non_empty_indices, processed_summary) + + return full_batch_summary + + def encode_state(self, state_obs: Dict[str, torch.Tensor]) -> torch.Tensor: + """ + Encodes the state observation into a fixed-size vector. + """ + + my_layout_vec = self.my_layout_processor(state_obs['my_layout_tensor']) + opp_layout_vec = self.opponent_layout_processor(state_obs['opponent_layout_tensor']) + unaccounted_mask_vec = self.unaccounted_mask_processor(state_obs['unaccounted_cards_mask']) + trick_vec = self.trick_processor(state_obs['trick_card_ids'], state_obs['trick_card_ids'] != -1) + cemetery_vec = self.cemetery_processor(state_obs['cemetery_card_ids'], state_obs['cemetery_card_ids'] != -1) + + my_hidden_vec = self.indicator_processor(state_obs['my_hidden_indicators']) + opp_hidden_vec = self.indicator_processor(state_obs['opponent_hidden_indicators']) + indicator_vec = my_hidden_vec + opp_hidden_vec + + context_vec = self.context_processor(state_obs['context']) + history_vec = self._forward_history(state_obs['action_history']) + + state_encoding = torch.cat([ + my_layout_vec, opp_layout_vec, unaccounted_mask_vec, + trick_vec, cemetery_vec, + indicator_vec, context_vec, history_vec + ], dim=-1) + + return state_encoding + + def predict_q(self, state_encoding: torch.Tensor, action_obs: Dict[str, torch.Tensor]) -> torch.Tensor: + """ + Predicts Q-values for given state encodings and actions. 
+ """ + + action_cards_vec = self.action_card_processor(action_obs['action_card_ids'], action_obs['action_card_ids'] != -1) + + trump_action_id = action_obs['trump_action_id'] + safe_trump_id = trump_action_id.clone() + safe_trump_id[trump_action_id == -1] = self.trump_action_embedding.padding_idx + trump_action_embedded = self.trump_action_embedding(safe_trump_id) + processed_trump_vec = self.trump_action_processor(trump_action_embedded.squeeze(1)) + + is_card_play_action = (action_obs['action_card_ids'][:, 0] != -1).unsqueeze(-1).float() + is_trump_action = (action_obs['trump_action_id'][:, 0] != -1).unsqueeze(-1).float() + + masked_card_vec = action_cards_vec * is_card_play_action + masked_trump_vec = processed_trump_vec * is_trump_action + action_vec = masked_card_vec + masked_trump_vec + + fused = torch.cat([state_encoding, action_vec], dim=-1) + q_value = self.head(fused).squeeze(-1) + + return q_value + + def forward(self, state_obs: Dict[str, torch.Tensor], action_obs: Dict[str, torch.Tensor]) -> torch.Tensor: + """ + Forward pass to get Q-values. + """ + + state_encoding = self.encode_state(state_obs) + return self.predict_q(state_encoding, action_obs) \ No newline at end of file diff --git a/rlcard/agents/bauernskat/dmc_agent/reward.py b/rlcard/agents/bauernskat/dmc_agent/reward.py new file mode 100644 index 000000000..684bfd27a --- /dev/null +++ b/rlcard/agents/bauernskat/dmc_agent/reward.py @@ -0,0 +1,65 @@ +''' + File name: rlcard/games/bauernskat/dmc_agent/reward.py + Author: Oliver Czerwinski + Date created: 09/07/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +import numpy as np + +def _custom_centered_tanh(final_score: float, steepness: float, win_loss_threshold: int) -> float: + """ + Centered tanh function to compress score magnitudes. + """ + + if final_score >= win_loss_threshold: + adjusted_magnitude = float(final_score - win_loss_threshold) + return np.tanh(adjusted_magnitude * steepness) + elif final_score <= -win_loss_threshold: + adjusted_magnitude = float(abs(final_score) - win_loss_threshold) + return -np.tanh(adjusted_magnitude * steepness) + else: + return 0.0 + +def calculate_game_score_reward(final_score: float) -> float: + """ + Returns the raw game score as reward. + """ + + return float(final_score) + +def calculate_binary_reward(final_score: float) -> float: + """ + Returns +1.0 for win or -1.0 for loss as reward. + """ + + return float(np.sign(final_score)) + +def calculate_hybrid_reward(my_final_pips: int, opponent_final_pips: int, final_score: float, steepness: float = 0.009, threshold: int = 18, score_weight: float = 0.5, win_bonus_magnitude: float = 1.0) -> float: + """ + Calculates a hybrid reward based on game outcome, pip difference and score magnitude. 
+ """ + + # Sign of the outcome + outcome_sign = 0.0 + if final_score >= threshold: + outcome_sign = 1.0 + elif final_score <= -threshold: + outcome_sign = -1.0 + + # Safety for draw + if outcome_sign == 0.0: + return 0.0 + + # Pip difference + r_base = float(my_final_pips - opponent_final_pips) + + # Score multiplier + compressed_score = _custom_centered_tanh(final_score, steepness=steepness, win_loss_threshold=threshold) + m_score = 1.0 + score_weight * abs(compressed_score) + + total_magnitude = win_bonus_magnitude + abs(r_base * m_score) + final_reward = outcome_sign * total_magnitude + + return final_reward \ No newline at end of file diff --git a/rlcard/agents/bauernskat/dmc_agent/trainer.py b/rlcard/agents/bauernskat/dmc_agent/trainer.py new file mode 100644 index 000000000..221f9854d --- /dev/null +++ b/rlcard/agents/bauernskat/dmc_agent/trainer.py @@ -0,0 +1,685 @@ +''' + File name: rlcard/games/bauernskat/dmc_agent/trainer.py + Author: Oliver Czerwinski + Date created: 08/14/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +import os +import threading +import time +import datetime +import pprint +import traceback +import copy +import csv +import json +import logging +from typing import Dict, Any, Literal, get_origin, get_args +import argparse +import queue +import dataclasses +import random + +import numpy as np +import torch +from torch import multiprocessing as mp +from multiprocessing.synchronize import Lock as LockType +from multiprocessing.queues import Queue as QueueType +from multiprocessing.sharedctypes import Synchronized as SynchronizedType +from torch import nn +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter + +from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage +from torchrl.data.replay_buffers.samplers import PrioritizedSampler +from tensordict import TensorDict + +import rlcard + +from rlcard.agents.bauernskat.dmc_agent.config import TrainerConfig +from rlcard.agents.bauernskat.dmc_agent.model import BauernskatNet +from rlcard.agents.bauernskat.dmc_agent.agent import Estimator, AgentDMC_Actor +from rlcard.agents.bauernskat.dmc_agent.utils import ObsPreprocessor, setup_logging, TrainingLogger, AgentEvaluator +from rlcard.agents.bauernskat.dmc_agent.reward import calculate_hybrid_reward, calculate_binary_reward, calculate_game_score_reward + + +log = logging.getLogger('agent_dmc_trainer') + +def format_time(seconds: float) -> str: + """ + Formats seconds into a HH:MM:SS. + """ + + return str(datetime.timedelta(seconds=int(seconds))) + +def gather_metadata(config: TrainerConfig) -> Dict: + """ + Gathers metadata about the training run. + """ + + date_start = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + + slurm_data = {k.replace('SLURM_', '').lower(): v for k, v in os.environ.items() if k.startswith('SLURM')} or None + env_whitelist = ('USER', 'HOSTNAME') + safe_env = {k: v for k, v in os.environ.items() if k.startswith('SLURM') or k in env_whitelist} + + def custom_dict_factory(data): + """ + Handles non-serializable types in dataclasses. + """ + + return {k: str(v) if isinstance(v, torch.device) else v for k, v in data} + + config_dict = dataclasses.asdict(config, dict_factory=custom_dict_factory) + + return dict(date_start=date_start, date_end=None, successful=False, + slurm=slurm_data, env=safe_env, config=config_dict) + + +class FileWriter: + """ + Handles logging to files and saving metadata. 
+ """ + + def __init__(self, xpid: str, rootdir: str, config: TrainerConfig): + """ + Initializes FileWriter. + """ + + self.xpid = xpid + self._tick = 0 + self.metadata = gather_metadata(config) + self.metadata['xpid'] = self.xpid + + self._logger = logging.getLogger(f'filewriter/{self.xpid}') + self._logger.setLevel(logging.INFO) + self._logger.propagate = False + + self.basepath = os.path.join(os.path.expandvars(os.path.expanduser(rootdir)), self.xpid) + os.makedirs(self.basepath, exist_ok=True) + + self.paths = { + 'msg': f'{self.basepath}/out.log', 'logs': f'{self.basepath}/logs.csv', + 'fields': f'{self.basepath}/fields.csv', 'meta': f'{self.basepath}/meta.json'} + + self._save_metadata() + + fhandle = logging.FileHandler(self.paths['msg']) + fhandle.setFormatter(logging.Formatter('%(message)s')) + self._logger.addHandler(fhandle) + + self.fieldnames = ['_tick', '_time'] + + if os.path.exists(self.paths['logs']): + with open(self.paths['fields'], 'r') as csvfile: + self.fieldnames = list(csv.reader(csvfile))[0] + + def log(self, to_log: Dict): + """ + Logs values to a CSV file. + """ + + to_log.update({'_tick': self._tick, '_time': time.time()}) + self._tick += 1 + + new_fields = any(k not in self.fieldnames for k in to_log) + + if new_fields: + self.fieldnames.extend(k for k in to_log if k not in self.fieldnames) + with open(self.paths['fields'], 'w') as f: + csv.writer(f).writerow(self.fieldnames) + + if to_log['_tick'] == 1: + with open(self.paths['logs'], 'a') as f: + f.write(f'# {",".join(self.fieldnames)}\n') + + self._logger.info(f'LOG | {", ".join([f"{k}: {v}" for k,v in sorted(to_log.items())])}') + + with open(self.paths['logs'], 'a') as f: + csv.DictWriter(f, fieldnames=self.fieldnames).writerow(to_log) + + def close(self, successful: bool = True): + """ + Closes the FileWriter and saves final metadata. + """ + + self.metadata['date_end'] = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + self.metadata['successful'] = successful + + self._save_metadata() + + for handler in self._logger.handlers[:]: + if isinstance(handler, logging.FileHandler): + handler.close() + self._logger.removeHandler(handler) + + def _save_metadata(self): + """ + Saves metadata to a JSON file. + """ + + with open(self.paths['meta'], 'w') as f: + json.dump(self.metadata, f, indent=4, sort_keys=True) + + +def act(actor_id: int, config: TrainerConfig, actor_model: nn.Module, sample_queue: QueueType, + log_queue: QueueType, shared_epsilon: SynchronizedType, shared_trump_prob: SynchronizedType, + shared_teacher_eps: SynchronizedType, dropped_batches_counter: SynchronizedType): + """ + Main loop for an actor process. 
+ """ + + setup_logging() + + seed = config.seed + actor_id + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + log = logging.getLogger('agent_dmc_trainer') + obs_preprocessor = ObsPreprocessor() + + try: + log.info(f'Actor {actor_id} started.') + + env_config = { + 'seed': seed, + 'information_level': config.information_level} + + env = rlcard.make(config.env, config=env_config) + + agent = AgentDMC_Actor(actor_model, 'cpu', use_teacher=config.use_teacher_forcing) + + # Main loop + while True: + aggregated_samples = [] + + for _ in range(config.actor_game_batch_size): + trajectories = {p_id: [] for p_id in range(env.num_players)} + state, player_id = env.reset() + + while not env.is_over(): + action, action_obs = agent.step(state, env, epsilon=shared_epsilon.value, trump_rule_prob=shared_trump_prob.value, teacher_epsilon=shared_teacher_eps.value) + trajectories[player_id].append((state['obs'], action_obs)) + state, player_id = env.step(action) + + # Reward calculation + final_scores = env.get_payoffs() + final_pips = env.get_scores() + + payoffs = np.zeros(2, dtype=np.float32) + + if config.reward_type == 'hybrid': + payoffs[0] = calculate_hybrid_reward( + my_final_pips=final_pips[0], + opponent_final_pips=final_pips[1], + final_score=final_scores[0], + steepness=config.reward_shaping_steepness, + threshold=config.reward_shaping_threshold, + score_weight=config.reward_shaping_score_weight, + win_bonus_magnitude=config.reward_shaping_win_bonus + ) + payoffs[1] = calculate_hybrid_reward( + my_final_pips=final_pips[1], + opponent_final_pips=final_pips[0], + final_score=final_scores[1], + steepness=config.reward_shaping_steepness, + threshold=config.reward_shaping_threshold, + score_weight=config.reward_shaping_score_weight, + win_bonus_magnitude=config.reward_shaping_win_bonus + ) + + elif config.reward_type == 'binary': + payoffs[0] = calculate_binary_reward(final_scores[0]) + payoffs[1] = calculate_binary_reward(final_scores[1]) + + elif config.reward_type == 'game_score': + payoffs[0] = calculate_game_score_reward(final_scores[0]) + payoffs[1] = calculate_game_score_reward(final_scores[1]) + + if config.log_p0_p1_payoffs: + log_queue.put({'p0_payoff': env.get_payoffs()[0], 'p1_payoff': env.get_payoffs()[1]}) + + samples_this_game = [] + + # Monte Carlo learning + for p_id, trajectory in trajectories.items(): + if trajectory: + G = payoffs[p_id] + for s_obs, a_obs in trajectory: + sample = {"observation": s_obs, "action": a_obs, + "next": { "observation": s_obs, "reward": G, "done": True, "action": a_obs }} + samples_this_game.append(obs_preprocessor(sample)) + + if samples_this_game: + aggregated_samples.extend(samples_this_game) + + if aggregated_samples: + try: + sample_queue.put(aggregated_samples, timeout=config.sample_queue_put_timeout) + except queue.Full: + with dropped_batches_counter.get_lock(): + dropped_batches_counter.value += 1 + + except KeyboardInterrupt: + log.info(f"Actor {actor_id} interrupted.") + except Exception as e: + log.error(f'Exception in actor process {actor_id}: {e}\n{traceback.format_exc()}') + raise e + + +def learn(config: TrainerConfig, learner_estimator: Estimator, actor_model: nn.Module, replay_buffer: Any, + frames_counter: SynchronizedType, learner_lock: LockType, buffer_lock: LockType, log_queue: QueueType, + latest_loss: SynchronizedType, latest_mean_q: SynchronizedType, latest_lr: SynchronizedType): + """ + Main loop for a learner thread. 
+ """ + + device = learner_estimator.device + last_log_frame = 0 + + while frames_counter.value < config.total_frames: + with buffer_lock: + if len(replay_buffer) < config.min_buffer_size_to_learn: + time.sleep(1) + continue + try: + batch = replay_buffer.sample(config.batch_size) + except Exception: + time.sleep(0.1) + continue + + is_weights = torch.from_numpy(batch.get("_weight").cpu().numpy().copy()).to(device).float() + indices = torch.from_numpy(batch.get("index").cpu().numpy().copy()) + rewards = torch.from_numpy(batch.get(("next", "reward")).cpu().numpy().copy()).to(device).float() + + state_batch = batch.get("observation").to(device) + action_batch = batch.get("action").to(device) + + targets = rewards.clone() + + mean_q = float(targets.mean().item()) + + # Training step + with learner_lock: + learner_estimator.qnet.train() + predicted_q = learner_estimator.qnet(state_batch, action_batch).clone() + + td_errors = predicted_q - targets + squared_errors = td_errors * td_errors + + with torch.no_grad(): + new_priorities = td_errors.abs().cpu().numpy().copy() + + weighted_loss = (squared_errors * is_weights).mean() + + learner_estimator.optimizer.zero_grad() + weighted_loss.backward() + torch.nn.utils.clip_grad_norm_(learner_estimator.qnet.parameters(), config.gradient_clip_norm) + learner_estimator.optimizer.step() + learner_estimator.scheduler.step() + + # Sync actor model + with torch.no_grad(): + for p_learner, p_actor in zip(learner_estimator.qnet.parameters(), actor_model.parameters()): + p_actor.data.copy_(p_learner.data) + + frames_counter.value += config.batch_size + + if frames_counter.value - last_log_frame >= config.log_every_frames: + latest_loss.value = weighted_loss.item() + latest_mean_q.value = mean_q + latest_lr.value = learner_estimator.scheduler.get_last_lr()[0] + + log_queue.put({ + 'type': 'train_stats', 'frames': frames_counter.value, 'loss': latest_loss.value, + 'mean_q': mean_q, 'learning_rate': latest_lr.value + }) + last_log_frame = frames_counter.value + + with buffer_lock: + replay_buffer.update_priority(indices, torch.from_numpy(new_priorities)) + + +class DMCTrainer: + """ + Trainer for the DMC agent. + """ + + def __init__(self, config: TrainerConfig): + """ + Initializes DMCTrainer. + """ + + self.config = config + self.plogger = FileWriter(xpid=config.xpid, rootdir=config.savedir, config=self.config) + self.writer = None + + if config.log_to_tensorboard: + tb_dir = os.path.join(config.savedir, config.xpid, 'tensorboard_logs') + self.writer = SummaryWriter(log_dir=tb_dir) + log.info(f"TensorBoard logging to {tb_dir}") + + self.checkpointpath = os.path.join(os.path.expandvars( + os.path.expanduser(config.savedir)), config.xpid, "model.tar") + + self.shutdown_event = threading.Event() + self.evaluator = AgentEvaluator(self.config) + + self.actor_processes = [] + self.learner_threads = [] + self.logger = None + self.ingest_thread = None + self.eval_thread = None + + def _setup_components(self): + """ + Sets up multiprocessing components. 
+ """ + + cfg = self.config + self.ctx = mp.get_context('spawn') + + log.info(f"Using learner device: {cfg.device}") + + # T_0 conversion + t0_in_steps = cfg.cosine_T0 // cfg.batch_size + log.info(f"Scheduler T_0 (frames): {cfg.cosine_T0}, Batch Size: {cfg.batch_size} -> T_0 (steps): {t0_in_steps}") + + self.learner_estimator = Estimator( + cfg.model_config, + cfg.learning_rate, + cfg.lr_gamma, + device=cfg.device, + weight_decay=cfg.weight_decay, + cosine_T0=t0_in_steps, + cosine_T_mult=cfg.cosine_T_mult, + cosine_eta_min=cfg.cosine_eta_min + ) + + self.actor_model = BauernskatNet(cfg.model_config).to('cpu') + self.actor_model.share_memory() + self.actor_model.eval() + + self.sample_queue = self.ctx.Queue(maxsize=cfg.num_actors * cfg.actor_queue_size_multiplier) + + sampler = PrioritizedSampler(max_capacity=cfg.replay_buffer_size, alpha=cfg.per_alpha, beta=cfg.per_beta) + self.replay_buffer = TensorDictReplayBuffer( + storage=LazyTensorStorage(max_size=cfg.replay_buffer_size), + sampler=sampler, batch_size=cfg.batch_size) + + log.info("Seeding the replay buffer...") + try: + obs_preprocessor = ObsPreprocessor() + temp_env = rlcard.make(cfg.env, config={'seed': cfg.seed + 999}) + temp_agent = AgentDMC_Actor(self.actor_model, 'cpu') + + state, _ = temp_env.reset() + _, action_obs = temp_agent.step(state, temp_env) + + seed_sample = {"observation": state['obs'], "action": action_obs, "next": { + "observation": state['obs'], "reward": 0.0, "done": False, "action": action_obs}} + self.replay_buffer.add(TensorDict(obs_preprocessor(seed_sample), batch_size=[])) + self.replay_buffer.empty() + + log.info("Replay buffer seeded successfully.") + except Exception as e: + log.error(f"Failed to seed replay buffer: {e}") + raise + + # Shared variables + self.log_queue = self.ctx.Queue() + self.frames = self.ctx.Value('Q', 0) + self.learner_lock = self.ctx.Lock() + self.buffer_lock = self.ctx.Lock() + self.latest_loss = self.ctx.Value('f', 0.0) + self.latest_mean_q = self.ctx.Value('f', 0.0) + self.latest_lr = self.ctx.Value('f', cfg.learning_rate) + self.avg_p0_payoff = self.ctx.Value('f', 0.0) + self.current_epsilon = self.ctx.Value('f', cfg.epsilon_start) + self.dropped_batches_total = self.ctx.Value('Q', 0) + self.total_elapsed_time = self.ctx.Value('d', 0.0) + + # Trump + initial_trump_prob = cfg.trump_start if cfg.use_rule_based_trump_decay else 0.0 + self.current_trump_prob = self.ctx.Value('f', initial_trump_prob) + + # Teacher forcing + self.current_teacher_eps = self.ctx.Value('f', cfg.teacher_start if cfg.use_teacher_forcing else 0.0) + + # Load model + if cfg.load_model and os.path.exists(self.checkpointpath): + checkpoint = torch.load(self.checkpointpath, map_location=cfg.device, weights_only=False) + + self.learner_estimator.qnet.load_state_dict(checkpoint['model_state_dict']) + + self.learner_estimator.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + if 'scheduler_state_dict' in checkpoint: + self.learner_estimator.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + self.frames.value = checkpoint.get('frames', 0) + self.current_epsilon.value = checkpoint.get('epsilon', cfg.epsilon_start) + self.total_elapsed_time.value = checkpoint.get('total_elapsed_time', 0.0) + + log.info(f"Resuming job from {self.frames.value} frames with epsilon {self.current_epsilon.value:.4f} after {format_time(self.total_elapsed_time.value)} of training.") + + # Initial sync of actor model + with torch.no_grad(): + for p_learner, p_actor in zip(self.learner_estimator.qnet.parameters(), 
self.actor_model.parameters()): + p_actor.data.copy_(p_learner.data) + + def _sample_ingest_worker(self): + """ + Ingests samples from the sample queue into the replay buffer. + """ + + log.info("Sample ingest worker started.") + + while not self.shutdown_event.is_set(): + try: + sample_batch = self.sample_queue.get(timeout=1.0) + if sample_batch is None: break + + with self.buffer_lock: + for sample in sample_batch: + self.replay_buffer.add(TensorDict(sample, batch_size=[])) + + except queue.Empty: + continue + except (KeyboardInterrupt, EOFError): + break + + log.info("Sample ingest worker terminated.") + + def start(self): + """ + Starts the training process. + """ + + cfg = self.config + self._setup_components() + + self.logger = TrainingLogger(self) + self.logger.start() + + self.ingest_thread = threading.Thread(target=self._sample_ingest_worker, daemon=True) + self.ingest_thread.start() + + self.actor_processes = [self.ctx.Process(target=act, args=(i, cfg, self.actor_model, self.sample_queue, + self.log_queue, self.current_epsilon, self.current_trump_prob, self.current_teacher_eps, + self.dropped_batches_total)) + for i in range(cfg.num_actors)] + for p in self.actor_processes: p.start() + + self.learner_threads = [threading.Thread(target=learn, args=(cfg, self.learner_estimator, self.actor_model, + self.replay_buffer, self.frames, self.learner_lock, self.buffer_lock, self.log_queue, + self.latest_loss, self.latest_mean_q, self.latest_lr)) for _ in range(cfg.num_threads)] + for t in self.learner_threads: t.start() + + try: + last_checkpoint_frame, last_eval_frame = 0, 0 + + resumed_time = self.total_elapsed_time.value + start_time = time.time() + + while self.frames.value < cfg.total_frames: + time.sleep(1) + current_frames = self.frames.value + + # Epsilon Decay + if cfg.epsilon_decay_type == 'exponential': + eps = max(cfg.epsilon_end, cfg.epsilon_start * (cfg.epsilon_gamma ** current_frames)) + else: + ratio = min(1.0, current_frames / cfg.epsilon_decay_frames) + eps = cfg.epsilon_start - (cfg.epsilon_start - cfg.epsilon_end) * ratio + self.current_epsilon.value = eps + + # Trump Decay + if cfg.use_rule_based_trump_decay: + trump_ratio = min(1.0, current_frames / cfg.trump_decay_frames) + self.current_trump_prob.value = cfg.trump_start - (cfg.trump_start - cfg.trump_end) * trump_ratio + else: + self.current_trump_prob.value = 0.0 + + # Teacher Forcing Decay + if cfg.use_teacher_forcing: + t_ratio = min(1.0, current_frames / cfg.teacher_decay_frames) + t_eps = cfg.teacher_start - (cfg.teacher_start - cfg.teacher_end) * t_ratio + self.current_teacher_eps.value = t_eps + + with self.buffer_lock: mem_size = len(self.replay_buffer) + + current_session_duration = time.time() - start_time + self.total_elapsed_time.value = resumed_time + current_session_duration + + status_text = (f"\rTime: {format_time(self.total_elapsed_time.value)} | " + f"Step: {current_frames/1e6:.2f}M/{cfg.total_frames/1e6:.1f}M | " + f"Mem: {mem_size/1e3:.1f}k | Teacher-ε: {self.current_teacher_eps.value:.4f} | " + f"Trump-ε: {self.current_trump_prob.value:.4f} | Random-ε: {eps:.4f} | " + f"LR: {self.latest_lr.value:.3e} | " + f"ØQ: {self.latest_mean_q.value:+.4f} | Loss: {self.latest_loss.value:.4f} | " + f"ØPayoff: {self.avg_p0_payoff.value:+.2f}") + + print(status_text, end="", flush=True) + + if current_frames - last_checkpoint_frame >= cfg.save_every_frames: + self.checkpoint() + last_checkpoint_frame = current_frames + + # Evaluation + if current_frames - last_eval_frame >= cfg.eval_every: + if 
self.eval_thread is not None and self.eval_thread.is_alive(): + pass + else: + with self.learner_lock: + eval_net_copy = copy.deepcopy(self.learner_estimator.qnet).to(cfg.device) + + def run_eval_thread(): + """ + Runs the evaluation in a separate thread. + """ + + try: + self.evaluator.evaluate(eval_net_copy, current_frames, self.writer) + except Exception as e: + log.error(f"Error in evaluation thread: {e}\n{traceback.format_exc()}") + + self.eval_thread = threading.Thread(target=run_eval_thread, daemon=True) + self.eval_thread.start() + + last_eval_frame = current_frames + + except KeyboardInterrupt: + print("\nTraining interrupted by user.") + finally: + print("\nTerminating processes and saving final model...") + self.shutdown_event.set() + + for p in self.actor_processes: + if p.is_alive(): p.terminate(); p.join(timeout=cfg.process_join_timeout) + if self.ingest_thread.is_alive(): self.ingest_thread.join(timeout=cfg.process_join_timeout) + + if self.eval_thread is not None and self.eval_thread.is_alive(): + print("Waiting for final evaluation to complete...") + self.eval_thread.join(timeout=120) + + if self.logger: self.logger.stop() + + self.checkpoint() + self.plogger.close() + if self.writer: self.writer.close() + log.info("Trainer shutdown complete.") + + def checkpoint(self): + """ + Saves the current model checkpoint. + """ + + cfg = self.config + log.info(f"Saving full training checkpoint to {self.checkpointpath}") + + checkpoint = { + 'model_state_dict': self.learner_estimator.qnet.state_dict(), + 'optimizer_state_dict': self.learner_estimator.optimizer.state_dict(), + 'scheduler_state_dict': self.learner_estimator.scheduler.state_dict(), + 'frames': self.frames.value, + 'epsilon': self.current_epsilon.value, + 'total_elapsed_time': self.total_elapsed_time.value, + 'config': cfg + } + + torch.save(checkpoint, self.checkpointpath) + + checkpoint_dir = os.path.dirname(self.checkpointpath) + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + backup_path = os.path.join(checkpoint_dir, f"model_{timestamp}_frame{self.frames.value}.tar") + torch.save(checkpoint, backup_path) + + log.info(f"Backup saved to {backup_path}") + + inference_path = os.path.join(checkpoint_dir, f"inference_model_{self.frames.value}.pt") + torch.save(self.learner_estimator.qnet.state_dict(), inference_path) + + log.info(f"Inference checkpoint saved to {inference_path}") + +def main(): + """ + Main function to run the DMC trainer. 
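+
+    Example invocation (illustrative flag values; booleans defaulting to
+    True get a --no-<name> switch):
+
+        python -m rlcard.agents.bauernskat.dmc_agent.trainer \
+            --xpid my_run --num_actors 4 --no-load_model \
+            --reward_type binary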
+ """ + + try: + mp.set_start_method('spawn', force=True) + except RuntimeError: + pass + + setup_logging() + parser = argparse.ArgumentParser("Agent DMC Trainer for RLCard") + + for field in dataclasses.fields(TrainerConfig): + if not field.init or field.name == "model_config": continue + + if field.type == bool: + if field.default: + parser.add_argument(f'--no-{field.name}', dest=field.name, action='store_false', help=f"Disable {field.name}") + else: + parser.add_argument(f'--{field.name}', dest=field.name, action='store_true', help=f"Enable {field.name}") + parser.set_defaults(**{field.name: field.default}) + else: + kwargs = {'type': field.type, 'default': field.default, 'help': f"Set {field.name} (default: {field.default})"} + if get_origin(field.type) is Literal: + kwargs['choices'] = get_args(field.type) + kwargs['type'] = type(kwargs['choices'][0]) + parser.add_argument(f'--{field.name}', **kwargs) + + args = parser.parse_args() + config = TrainerConfig(**vars(args)) + os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda + + trainer = DMCTrainer(config) + log.info(f"Starting training for {config.xpid} with config:\n{pprint.pformat(dataclasses.asdict(config))}") + trainer.start() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/rlcard/agents/bauernskat/dmc_agent/utils.py b/rlcard/agents/bauernskat/dmc_agent/utils.py new file mode 100644 index 000000000..81e196de6 --- /dev/null +++ b/rlcard/agents/bauernskat/dmc_agent/utils.py @@ -0,0 +1,337 @@ +''' + File name: rlcard/games/bauernskat/dmc_agent/utils.py + Author: Oliver Czerwinski + Date created: 08/15/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +import logging +import time +import queue +import threading +import numpy as np +import torch +from typing import Dict, Any + +import rlcard +from rlcard.agents.bauernskat import rule_agents as bauernskat_rule_agents +from rlcard.agents.bauernskat.dmc_agent.agent import AgentDMC_Actor + +from rlcard.agents.bauernskat.dmc_agent.config import MAX_TRICK_SIZE, MAX_CEMETERY_SIZE + +def setup_logging(level=logging.INFO): + """ + Prepares the logging. + """ + + logger = logging.getLogger('agent_dmc_trainer') + if logger.hasHandlers(): + return + + logger.setLevel(level) + shandle = logging.StreamHandler() + shandle.setFormatter( + logging.Formatter( + '[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] ' + '%(message)s')) + logger.addHandler(shandle) + logger.propagate = False + +class ObsPreprocessor: + """ + Preprocesses observations for the DMC agent. + """ + + def __init__(self): + """ + Initialized ObsPreprocessor. + """ + + self.pad_keys = { + 'trick_card_ids': MAX_TRICK_SIZE, + 'cemetery_card_ids': MAX_CEMETERY_SIZE, + } + + def _pad_obs(self, obs_dict: Dict) -> Dict: + """ + Pads observation of different lengths to a fixed size. + """ + + padded_dict = obs_dict.copy() + + for key, max_len in self.pad_keys.items(): + if key in padded_dict: + original_list = padded_dict[key] + padding_needed = max_len - len(original_list) + if padding_needed > 0: + padded_dict[key] = original_list + [-1] * padding_needed + + return padded_dict + + def _prepare_for_tensordict(self, data: Dict) -> Dict: + """ + Converts a list to a numpy array for tensordict compatibility. 
+ """ + + for key, value in data.items(): + if isinstance(value, dict): + self._prepare_for_tensordict(value) + elif isinstance(value, list): + data[key] = np.array(value, dtype=np.int32) + elif isinstance(value, np.ndarray) and value.dtype == np.float64: + data[key] = value.astype(np.float32) + + return data + + def __call__(self, obs_or_sample: Dict[str, Any]) -> Dict[str, Any]: + """ + Pads and converts for tensordict compatibility. + """ + + if 'observation' in obs_or_sample: + obs_or_sample['observation'] = self._pad_obs(obs_or_sample['observation']) + if 'next' in obs_or_sample and 'observation' in obs_or_sample['next']: + obs_or_sample['next']['observation'] = self._pad_obs(obs_or_sample['next']['observation']) + return self._prepare_for_tensordict(obs_or_sample) + else: + padded_obs = self._pad_obs(obs_or_sample) + return self._prepare_for_tensordict(padded_obs) + + +class TrainingLogger: + """ + Handles logging during training. + """ + + def __init__(self, trainer_instance): + """ + Initializes TrainingLogger. + """ + + self.config = trainer_instance.config + self.plogger = trainer_instance.plogger + self.writer = trainer_instance.writer + self.log_queue = trainer_instance.log_queue + self.shutdown_event = trainer_instance.shutdown_event + self.replay_buffer = trainer_instance.replay_buffer + self.buffer_lock = trainer_instance.buffer_lock + self.avg_p0_payoff = trainer_instance.avg_p0_payoff + self.dropped_batches_total = trainer_instance.dropped_batches_total + self.total_elapsed_time = trainer_instance.total_elapsed_time + + self.current_epsilon = trainer_instance.current_epsilon + self.current_teacher_eps = trainer_instance.current_teacher_eps + self.current_trump_prob = trainer_instance.current_trump_prob + + self.thread = None + self.log = logging.getLogger('agent_dmc_trainer') + + def start(self): + """ + Starts a logging thread. + """ + + self.thread = threading.Thread(target=self._log_worker, daemon=True) + self.thread.start() + + def stop(self): + """ + Stops the logging thread. + """ + + if self.thread and self.thread.is_alive(): + self.thread.join(timeout=self.config.process_join_timeout) + + def _log_worker(self): + """ + Processes log records from a queue. 
+ """ + + payoff_buffer = [] + last_log_time = time.time() + + latest_frames, latest_loss, latest_mean_q, latest_lr = 0, 0.0, 0.0, 0.0 + + def perform_log(): + nonlocal payoff_buffer, last_log_time + if latest_frames > 0: + with self.buffer_lock: + buffer_size = len(self.replay_buffer) + + log_data = { + 'Training/frames': latest_frames, + 'Training/loss': latest_loss, + 'Training/mean_q_values': latest_mean_q, + 'Training/learning_rate': latest_lr, + 'Performance/buffer_size': buffer_size, + 'Performance/total_dropped_batches': self.dropped_batches_total.value, + 'Performance/total_training_time_hours': self.total_elapsed_time.value / 3600.0, + 'Exploration/Random-epsilon': self.current_epsilon.value, + } + + if self.config.use_teacher_forcing: + log_data['Exploration/Teacher-epsilon'] = self.current_teacher_eps.value + + if self.config.use_rule_based_trump_decay: + log_data['Exploration/Trump-epsilon'] = self.current_trump_prob.value + + if payoff_buffer: + avg_payoff = np.mean(payoff_buffer) + self.avg_p0_payoff.value = avg_payoff + log_data['Performance/avg_p0_payoff_5s'] = avg_payoff + log_data['Performance/total_games_in_5s'] = len(payoff_buffer) + payoff_buffer = [] + + self.plogger.log(log_data) + if self.writer: + for key, value in log_data.items(): + if key not in ['_tick', '_time']: + self.writer.add_scalar(key, value, latest_frames) + + last_log_time = time.time() + + while not self.shutdown_event.is_set(): + try: + record = self.log_queue.get(timeout=1.0) + if record is None: break + + if record.get('type') == 'train_stats': + latest_frames = record.get('frames', latest_frames) + latest_loss = record.get('loss', latest_loss) + latest_mean_q = record.get('mean_q', latest_mean_q) + latest_lr = record.get('learning_rate', latest_lr) + else: + payoff_buffer.append(record['p0_payoff']) + + except queue.Empty: + continue + except (KeyboardInterrupt, EOFError): + break + + if time.time() - last_log_time >= self.config.log_interval_seconds: + perform_log() + + perform_log() + self.log.info("Log worker terminated.") + + +class AgentEvaluator: + """ + Evaluates the agent against multiple rule-based agents. + """ + + def __init__(self, config): + """ + Initializes AgentEvaluator. + """ + + self.config = config + self.log = logging.getLogger('agent_dmc_trainer') + self.opponents = { + 'Random': bauernskat_rule_agents.BauernskatRandomRuleAgent(), + 'Frugal': bauernskat_rule_agents.BauernskatFrugalRuleAgent(), + 'Lookahead': bauernskat_rule_agents.BauernskatLookaheadRuleAgent(), + 'SHOT': bauernskat_rule_agents.BauernskatSHOTAlphaBetaRuleAgent() + } + + def evaluate(self, eval_net, current_frames, writer): + """ + Evaluates the agent against all opponents. 
+ """ + + self.log.info("Starting Evaluation Run...") + + with torch.no_grad(): + eval_agent = AgentDMC_Actor(eval_net, self.config.device, use_teacher=False) + + eval_env_config = { + 'seed': 500, + 'information_level': self.config.information_level + } + eval_env = rlcard.make(self.config.env, config=eval_env_config) + + total_p0_wins, total_p1_wins, total_p0_payoff, total_p1_payoff = 0, 0, 0.0, 0.0 + total_games_as_p0, total_games_as_p1 = 0, 0 + + win_rates_by_opponent = {} + avg_rewards_by_opponent = {} + + for name, opponent in self.opponents.items(): + games_per_opponent_half = self.config.num_eval_games // (2 * len(self.opponents)) + + eval_env.set_agents([eval_agent, opponent]) + p0_wins, p0_payoff = self._run_half(eval_env, games_per_opponent_half, agent_pos=0) + + eval_env.set_agents([opponent, eval_agent]) + p1_wins, p1_payoff = self._run_half(eval_env, games_per_opponent_half, agent_pos=1) + + # Accumulate stats + total_p0_wins += p0_wins + total_p0_payoff += p0_payoff + total_games_as_p0 += games_per_opponent_half + total_p1_wins += p1_wins + total_p1_payoff += p1_payoff + total_games_as_p1 += games_per_opponent_half + + # Individual Opponent Stats + p0_win_rate = p0_wins / games_per_opponent_half if games_per_opponent_half > 0 else 0 + p0_avg_reward = p0_payoff / games_per_opponent_half if games_per_opponent_half > 0 else 0 + p1_win_rate = p1_wins / games_per_opponent_half if games_per_opponent_half > 0 else 0 + p1_avg_reward = p1_payoff / games_per_opponent_half if games_per_opponent_half > 0 else 0 + self.log.info(f" vs {name}: P0 [WR: {p0_win_rate:.1%}, AvgR: {p0_avg_reward:+.2f}] | P1 [WR: {p1_win_rate:.1%}, AvgR: {p1_avg_reward:+.2f}]") + + # Combined Stats for specific opponent + total_games_this_opponent = games_per_opponent_half * 2 + if total_games_this_opponent > 0: + combined_win_rate = (p0_wins + p1_wins) / total_games_this_opponent + combined_avg_reward = (p0_payoff + p1_payoff) / total_games_this_opponent + win_rates_by_opponent[name] = combined_win_rate + avg_rewards_by_opponent[name] = combined_avg_reward + + # Overall Stats + overall_p0_win_rate = total_p0_wins / total_games_as_p0 if total_games_as_p0 > 0 else 0 + overall_p1_win_rate = total_p1_wins / total_games_as_p1 if total_games_as_p1 > 0 else 0 + overall_p0_avg_reward = total_p0_payoff / total_games_as_p0 if total_games_as_p0 > 0 else 0.0 + overall_p1_avg_reward = total_p1_payoff / total_games_as_p1 if total_games_as_p1 > 0 else 0.0 + + self.log.info(f"Overall Factual -> P0 [WR: {overall_p0_win_rate:.1%}, AvgR: {overall_p0_avg_reward:+.2f}] | P1 [WR: {overall_p1_win_rate:.1%}, AvgR: {overall_p1_avg_reward:+.2f}]") + + if writer: + writer.add_scalar('Evaluation/overall_p0_win_rate', overall_p0_win_rate, current_frames) + writer.add_scalar('Evaluation/overall_p1_win_rate', overall_p1_win_rate, current_frames) + writer.add_scalar('Evaluation/overall_p0_avg_reward', overall_p0_avg_reward, current_frames) + writer.add_scalar('Evaluation/overall_p1_avg_reward', overall_p1_avg_reward, current_frames) + + if win_rates_by_opponent: + writer.add_scalars('Evaluation/Combined_WinRate_vs_Opponent', win_rates_by_opponent, current_frames) + + if avg_rewards_by_opponent: + writer.add_scalars('Evaluation/Combined_AvgR_vs_Opponent', avg_rewards_by_opponent, current_frames) + + def _run_half(self, env, num_games, agent_pos): + """ + Runs games in a specified player role. 
+ """ + + total_wins = 0 + total_payoff = 0.0 + agent = env.agents[agent_pos] + opponent = env.agents[1 - agent_pos] + + for _ in range(num_games): + state, player_id = env.reset() + + while not env.is_over(): + if player_id == agent_pos: + action, _ = agent.eval_step(state, env) + else: + action = opponent.step(state) + state, player_id = env.step(action) + payoffs = env.get_payoffs() + total_payoff += payoffs[agent_pos] + + if payoffs[agent_pos] > 0: + total_wins += 1 + + return total_wins, total_payoff \ No newline at end of file diff --git a/rlcard/agents/bauernskat/rule_agents.py b/rlcard/agents/bauernskat/rule_agents.py new file mode 100644 index 000000000..ab32324f7 --- /dev/null +++ b/rlcard/agents/bauernskat/rule_agents.py @@ -0,0 +1,819 @@ +''' + File name: rlcard/games/bauernskat/rule_agents.py + Author: Oliver Czerwinski + Date created: 08/12/2025 + Date last modified: 12/25/2025 + Python Version: 3.9+ +''' + +import random +from collections import Counter +import math +import random + +from rlcard.games.bauernskat.action_event import ActionEvent, DeclareTrumpAction, PlayCardAction +from rlcard.games.bauernskat.card import BauernskatCard + + +def _get_card_strength(card: BauernskatCard, trump_suit: str, led_suit: str) -> int: + """ + Helper function to calculate the relative strength of a card. + """ + + STRENGTH_MAP = {'7': 0, '8': 1, '9': 2, 'Q': 3, 'K': 4, '10': 5, 'A': 6, 'J': 7} + + if trump_suit != 'G': + if card.rank == 'J': + jack_strength = {'C': 3, 'S': 2, 'H': 1, 'D': 0} + return 400 + jack_strength[card.suit] + if card.suit == trump_suit: + return 300 + STRENGTH_MAP[card.rank] + + if card.suit == led_suit: + STRENGTH_MAP_SUIT = {'7': 0, '8': 1, '9': 2, 'J': 3, 'Q': 4, 'K': 5, '10': 6, 'A': 7} + return 200 + STRENGTH_MAP_SUIT[card.rank] + + return 100 + STRENGTH_MAP[card.rank] + + +def _is_trump(card, ts): + """ + Helper function to determine if a card is a trump. + """ + + if ts == 'G': return False + return card.rank == 'J' or card.suit == ts + + +class BauernskatRandomRuleAgent: + """ + An agent that selects any legal action randomly. + """ + + def __init__(self, seed=None): + """ + Initialized BauernskatRandomRuleAgent. + """ + + self.use_raw = False + self.rng = random.Random(seed) + + def seed(self, seed=None): + """ + Sets a seed. + """ + + self.rng = random.Random(seed) + + def step(self, state): + """ + Selects a random legal action. + """ + + return self.rng.choice(list(state['legal_actions'].keys())) + + def eval_step(self, state): + """ + Selects a random legal action for evaluation. + """ + + action = self.step(state) + return action, {} + + +class BauernskatFrugalRuleAgent: + """ + A simple agent, that plays conservatively to minimize losses. + """ + + def __init__(self, seed=None): + """ + Initialized BauernskatFrugalRuleAgent. + """ + + self.use_raw = True + self.rng = random.Random(seed) + + def seed(self, seed=None): + """ + Sets a seed. + """ + + self.rng = random.Random(seed) + + def step(self, state): + """ + Selects a legal action with a conservative strategy. 
+ """ + + legal_action_ids = list(state['legal_actions'].keys()) + raw_info = state['raw_state_info'] + round_phase = raw_info['round_phase'] + + # Simple rule based trump declaration + if round_phase == 'declare_trump': + my_cards = raw_info['my_cards'] + num_jacks = sum(1 for card in my_cards if card.rank == 'J') + if num_jacks >= 3: + for action_id in legal_action_ids: + action = ActionEvent.from_action_id(action_id) + if isinstance(action, DeclareTrumpAction) and action.trump_suit == 'G': + return action_id + + suit_counts = Counter(card.suit for card in my_cards) + + if not suit_counts: + return self.rng.choice(legal_action_ids) + max_count = max(suit_counts.values()) + best_suits = [suit for suit, count in suit_counts.items() if count == max_count] + + if len(best_suits) == 1: + best_suit_choice = best_suits[0] + else: + best_rank_val = -1 + rank_order = BauernskatCard.ranks + best_suit_choice = best_suits[0] + for suit in best_suits: + cards_of_suit = [card for card in my_cards if card.suit == suit] + max_rank_in_suit = max(rank_order.index(c.rank) for c in cards_of_suit) + if max_rank_in_suit > best_rank_val: + best_rank_val = max_rank_in_suit + best_suit_choice = suit + + for action_id in legal_action_ids: + action = ActionEvent.from_action_id(action_id) + if isinstance(action, DeclareTrumpAction) and action.trump_suit == best_suit_choice: + return action_id + + return self.rng.choice(legal_action_ids) + + # Only play good cards if necessary or really worth it + if round_phase == 'play': + trick_moves = raw_info['trick_moves'] + trump_suit = raw_info['trump_suit'] + legal_play_actions = [a for a in (ActionEvent.from_action_id(aid) for aid in legal_action_ids) if isinstance(a, PlayCardAction)] + + if not trick_moves: + min_points = min(action.card.points for action in legal_play_actions) + min_point_actions = [action for action in legal_play_actions if action.card.points == min_points] + + non_trump_options = [action for action in min_point_actions if not _is_trump(action.card, trump_suit)] + if non_trump_options: + return non_trump_options[0].action_id + else: + return min_point_actions[0].action_id + + else: + led_card = trick_moves[0][1] + led_suit = led_card.suit + led_card_strength = _get_card_strength(led_card, trump_suit, led_suit) + + winning_moves = [action for action in legal_play_actions if _get_card_strength(action.card, trump_suit, led_suit) > led_card_strength] + losing_moves = [action for action in legal_play_actions if action not in winning_moves] + + should_win = False + + if winning_moves: + weakest_winning_move = min(winning_moves, key=lambda a: _get_card_strength(a.card, trump_suit, led_suit)) + potential_trick_value = led_card.points + weakest_winning_move.card.points + + if potential_trick_value >= 10: + should_win = True + + if should_win: + return min(winning_moves, key=lambda a: _get_card_strength(a.card, trump_suit, led_suit)).action_id + else: + if losing_moves: + return min(losing_moves, key=lambda a: a.card.points).action_id + else: + return min(winning_moves, key=lambda a: _get_card_strength(a.card, trump_suit, led_suit)).action_id + + return self.rng.choice(legal_action_ids) + + def eval_step(self, state): + """ + Selects a legal action for evaluation. + """ + + action = self.step(state) + return action, {} + + +class BauernskatLookaheadRuleAgent: + """ + An agent simulating the outcome of each legal move aswell as the opponents likely response. + """ + + def __init__(self, seed=None): + """ + Initialized BauernskatLookaheadRuleAgent. 
+ """ + + self.use_raw = True + self.rng = random.Random(seed) + + def seed(self, seed=None): + """ + Sets a seed. + """ + + self.rng = random.Random(seed) + + def step(self, state): + """ + Selects a legal action using lookahead simulations. + """ + + legal_action_ids = list(state['legal_actions'].keys()) + raw_info = state['raw_state_info'] + round_phase = raw_info['round_phase'] + + # Simple rule based trump declaration + if round_phase == 'declare_trump': + my_cards = raw_info['my_cards'] + num_jacks = sum(1 for card in my_cards if card.rank == 'J') + + if num_jacks >= 2: + for action_id in legal_action_ids: + action = ActionEvent.from_action_id(action_id) + if isinstance(action, DeclareTrumpAction) and action.trump_suit == 'G': + return action_id + + suit_counts = Counter(card.suit for card in my_cards) + if not suit_counts: return self.rng.choice(legal_action_ids) + + max_count = max(suit_counts.values()) + best_suits = [suit for suit, count in suit_counts.items() if count == max_count] + + if len(best_suits) == 1: + best_suit_choice = best_suits[0] + else: + best_rank_val = -1 + rank_order = BauernskatCard.ranks + best_suit_choice = best_suits[0] + for suit in best_suits: + cards_of_suit = [card for card in my_cards if card.suit == suit] + max_rank_in_suit = max(rank_order.index(c.rank) for c in cards_of_suit) + if max_rank_in_suit > best_rank_val: + best_rank_val = max_rank_in_suit + best_suit_choice = suit + + for action_id in legal_action_ids: + action = ActionEvent.from_action_id(action_id) + if isinstance(action, DeclareTrumpAction) and action.trump_suit == best_suit_choice: + return action_id + + return self.rng.choice(legal_action_ids) + + # Go through all legal moves and simulate opponent response to determine the best action + if round_phase == 'play': + legal_play_actions = [a for a in (ActionEvent.from_action_id(aid) for aid in legal_action_ids) if isinstance(a, PlayCardAction)] + + if len(legal_play_actions) == 1: + return legal_play_actions[0].action_id + + best_move = None + best_score = float('-inf') + + trick_moves = raw_info['trick_moves'] + + for action in legal_play_actions: + if not trick_moves: + # If starting a trick + score = self._score_leading_move(action, raw_info) + else: + # If answering a trick + score = self._score_following_move(action, raw_info) + + if score > best_score: + best_score = score + best_move = action + + return best_move.action_id if best_move else self.rng.choice(legal_action_ids) + + return self.rng.choice(legal_action_ids) + + def _score_leading_move(self, action: PlayCardAction, raw_info: dict) -> float: + """ + Rates card actions according to simulated opponent response. 
+ """ + + my_card = action.card + trump_suit = raw_info['trump_suit'] + opponent_visible_cards = raw_info['opponent_visible_cards'] + + # Simulation of response + led_suit = my_card.suit + is_led_trump = _is_trump(my_card, trump_suit) + + legal_replies = [] + if is_led_trump: + trumps_in_hand = [c for c in opponent_visible_cards if _is_trump(c, trump_suit)] + legal_replies = trumps_in_hand if trumps_in_hand else opponent_visible_cards + else: + suit_in_hand = [c for c in opponent_visible_cards if c.suit == led_suit and not _is_trump(c, trump_suit)] + legal_replies = suit_in_hand if suit_in_hand else opponent_visible_cards + + my_strength = _get_card_strength(my_card, trump_suit, led_suit) + + if not legal_replies: + return my_card.points - (1 if my_card.points >= 10 else 0) + + winning_replies = [c for c in legal_replies if _get_card_strength(c, trump_suit, led_suit) > my_strength] + losing_replies = [c for c in legal_replies if c not in winning_replies] + + # Outcome evaluation + if winning_replies: + # Bad if opponent can win + weakest_winner = min(winning_replies, key=lambda c: _get_card_strength(c, trump_suit, led_suit)) + points_lost = my_card.points + weakest_winner.points + return -points_lost + else: + # Good if opponent cannot win + most_frugal_discard = min(losing_replies, key=lambda c: c.points) + points_gained = my_card.points + most_frugal_discard.points + + # Bonus for making the opponent use a trump + if _is_trump(most_frugal_discard, trump_suit): + points_gained += 1 + + return points_gained + + def _score_following_move(self, action: PlayCardAction, raw_info: dict) -> float: + """ + Rates card actions when answering a card in the trick. + """ + + my_card = action.card + trump_suit = raw_info['trump_suit'] + led_card = raw_info['trick_moves'][0][1] + led_suit = led_card.suit + + my_strength = _get_card_strength(my_card, trump_suit, led_suit) + led_strength = _get_card_strength(led_card, trump_suit, led_suit) + + if my_strength > led_strength: + reward = led_card.points + my_card.points + + # Small cost for using a strong card + cost = my_strength / 50.0 + + # Bonus for winning tricks with high pips + if reward >= 10: + reward *= 1.5 + + return reward - cost + else: + # Lost the trick + points_lost = led_card.points + my_card.points + + # Bonus for using low pips cards + discard_bonus = (11 - my_card.points) / 10.0 + + return -points_lost + discard_bonus + + def eval_step(self, state): + """ + Selects a legal action for evaluation. + """ + + action = self.step(state) + return action, {} + + +class BauernskatSHOTAlphaBetaRuleAgent: + """ + A hybrid agent combining Simple Heuristic Search (SHOT), Alpha-Beta search, and PIMC for Bauernskat. + """ + + def __init__(self, num_simulations=16, alpha_beta_depth=2, use_shot=True, use_move_ordering=True, use_alpha_beta=True, seed=None): + """ + Initialized BauernskatSHOTAlphaBetaRuleAgent. + """ + + self.use_raw = True + self.rng = random.Random(seed) + + self.num_simulations = num_simulations + self.alpha_beta_depth = alpha_beta_depth + + self.use_shot = use_shot + self.use_move_ordering = use_move_ordering + self.use_alpha_beta = use_alpha_beta + + self.card_strength_cache = {} + + try: + deck = BauernskatCard.get_deck() + except NameError: + deck = [] + + for trump in ['G', 'C', 'S', 'H', 'D']: + for led in ['C', 'S', 'H', 'D']: + for card in deck: + self.card_strength_cache[(card, trump, led)] = _get_card_strength(card, trump, led) + + def seed(self, seed=None): + """ + Sets a seed. 
+ """ + + self.rng = random.Random(seed) + + def _cached_get_card_strength(self, card, trump_suit, led_suit): + """ + Retrieves cached card strength. + """ + + key = (card, trump_suit, led_suit) + return self.card_strength_cache[key] + + def _shallow_copy_state(self, state_info): + """ + Manual state copy to avoid overhead. + """ + + new_state = { + 'player_id': state_info['player_id'], + 'my_score': state_info['my_score'], + 'opponent_score': state_info['opponent_score'], + 'my_cards': list(state_info['my_cards']), + 'opponent_visible_cards': list(state_info['opponent_visible_cards']), + 'played_cards': set(state_info['played_cards']), + 'trick_moves': list(state_info['trick_moves']), + 'trump_suit': state_info['trump_suit'], + 'round_phase': state_info['round_phase'] + } + + return new_state + + def step(self, state): + """ + Selects a legal action using SHOT and Alpha-Beta search with move ordering. + """ + + legal_action_ids = list(state['legal_actions'].keys()) + raw_info = state['raw_state_info'] + round_phase = raw_info['round_phase'] + + # Simple rule based trump declaration + if round_phase == 'declare_trump': + my_cards = raw_info['my_cards'] + num_jacks = sum(1 for card in my_cards if card.rank == 'J') + + if num_jacks >= 2: + for action_id in legal_action_ids: + action = ActionEvent.from_action_id(action_id) + if isinstance(action, DeclareTrumpAction) and action.trump_suit == 'G': + return action_id + + suit_counts = Counter(card.suit for card in my_cards) + if not suit_counts: return self.rng.choice(legal_action_ids) + + max_count = max(suit_counts.values()) + best_suits = [suit for suit, count in suit_counts.items() if count == max_count] + + if len(best_suits) == 1: + best_suit_choice = best_suits[0] + else: + best_rank_val = -1 + rank_order = BauernskatCard.ranks + best_suit_choice = best_suits[0] + for suit in best_suits: + cards_of_suit = [card for card in my_cards if card.suit == suit] + max_rank_in_suit = max(rank_order.index(c.rank) for c in cards_of_suit) + if max_rank_in_suit > best_rank_val: + best_rank_val = max_rank_in_suit + best_suit_choice = suit + + for action_id in legal_action_ids: + action = ActionEvent.from_action_id(action_id) + if isinstance(action, DeclareTrumpAction) and action.trump_suit == best_suit_choice: + return action_id + + return self.rng.choice(legal_action_ids) + + # Play with SHOT and Alpha-Beta search + if round_phase == 'play': + legal_actions = [ActionEvent.from_action_id(aid) for aid in legal_action_ids if isinstance(ActionEvent.from_action_id(aid), PlayCardAction)] + if not legal_actions: return self.rng.choice(legal_action_ids) + if len(legal_actions) == 1: return legal_actions[0].action_id + + # Using SHOT for filtering actions + if self.use_shot: + candidate_actions = list(legal_actions) + num_rounds = math.ceil(math.log2(len(candidate_actions))) + sims_per_round = self.num_simulations // num_rounds if num_rounds > 0 else self.num_simulations + + for _ in range(num_rounds): + if len(candidate_actions) == 1: break + sims_per_candidate = max(1, sims_per_round // len(candidate_actions)) + scores = {action.action_id: 0 for action in candidate_actions} + + for action in candidate_actions: + post_move_state = self._shallow_copy_state(raw_info) + card_to_play = action.card + post_move_state['my_cards'].remove(card_to_play) + post_move_state['trick_moves'].append((post_move_state['player_id'], card_to_play)) + + if len(post_move_state['trick_moves']) == 2: + self._resolve_trick(post_move_state, original_player_id=raw_info['player_id']) + 
else:
+                            post_move_state['player_id'] = 1 - post_move_state['player_id']
+
+                        opponent_hand = self._determinize(post_move_state)
+
+                        for _ in range(sims_per_candidate):
+                            sim_state = self._shallow_copy_state(post_move_state)
+                            my_hand = list(sim_state['my_cards'])
+                            if self.use_alpha_beta:
+                                score = self._run_alpha_beta(
+                                    state_info=sim_state,
+                                    p0_hand=my_hand,
+                                    p1_hand=opponent_hand,
+                                    depth=self.alpha_beta_depth,
+                                    alpha=-999999,
+                                    beta=999999,
+                                    is_maximizing=(sim_state['player_id'] == raw_info['player_id']),
+                                    original_player_id=raw_info['player_id']
+                                )
+                            else:
+                                score = self._run_heuristic_playout(sim_state, my_hand, opponent_hand, raw_info['player_id'])
+                            scores[action.action_id] += score
+
+                    sorted_actions = sorted(candidate_actions, key=lambda a: scores[a.action_id], reverse=True)
+                    num_to_keep = math.ceil(len(sorted_actions) / 2)
+                    candidate_actions = sorted_actions[:num_to_keep]
+
+                return candidate_actions[0].action_id if candidate_actions else self.rng.choice(legal_action_ids)
+
+            else:
+                # PIMC
+                best_action = None
+                best_avg_score = float('-inf')
+
+                for action in legal_actions:
+                    post_move_state = self._shallow_copy_state(raw_info)
+                    card_to_play = action.card
+                    post_move_state['my_cards'].remove(card_to_play)
+                    post_move_state['trick_moves'].append((post_move_state['player_id'], card_to_play))
+
+                    if len(post_move_state['trick_moves']) == 2:
+                        self._resolve_trick(post_move_state, original_player_id=raw_info['player_id'])
+                    else:
+                        post_move_state['player_id'] = 1 - post_move_state['player_id']
+
+                    opponent_hand = self._determinize(post_move_state)
+
+                    total_score = 0
+
+                    for _ in range(self.num_simulations):
+                        sim_state = self._shallow_copy_state(post_move_state)
+                        my_hand = list(sim_state['my_cards'])
+                        if self.use_alpha_beta:
+                            score = self._run_alpha_beta(
+                                state_info=sim_state,
+                                p0_hand=my_hand,
+                                p1_hand=opponent_hand,
+                                depth=self.alpha_beta_depth,
+                                alpha=-999999,
+                                beta=999999,
+                                is_maximizing=(sim_state['player_id'] == raw_info['player_id']),
+                                original_player_id=raw_info['player_id']
+                            )
+                        else:
+                            score = self._run_heuristic_playout(sim_state, my_hand, opponent_hand, raw_info['player_id'])
+                        total_score += score
+
+                    avg_score = total_score / self.num_simulations
+                    if avg_score > best_avg_score:
+                        best_avg_score = avg_score
+                        best_action = action
+
+                return best_action.action_id if best_action else self.rng.choice(legal_action_ids)
+
+        return self.rng.choice(legal_action_ids)
+
+    def _determinize(self, state_info):
+        """
+        Creates a plausible opponent hand consistent with publicly known information.
+        """
+
+        all_cards = set(BauernskatCard.get_deck())
+
+        # Rule out cards that cannot be among the opponent's hidden cards
+        opponent_visible = set(state_info['opponent_visible_cards'])
+        played = set(state_info['played_cards'])
+        in_trick = {c for _, c in state_info['trick_moves']}
+        my_hand = set(state_info['my_cards'])
+
+        publicly_known_cards = opponent_visible | played | in_trick | my_hand
+
+        unknown_cards_pool = list(all_cards - publicly_known_cards)
+        self.rng.shuffle(unknown_cards_pool)
+
+        num_opponent_hidden = 16 - len(state_info['opponent_visible_cards'])
+
+        num_to_draw = min(num_opponent_hidden, len(unknown_cards_pool))
+
+        opponent_hidden_cards = unknown_cards_pool[:num_to_draw]
+
+        return state_info['opponent_visible_cards'] + opponent_hidden_cards
+
+    def _run_heuristic_playout(self, state_info, p0_hand, p1_hand, original_player_id):
+        """
+        Runs a heuristic simulation.
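+
+        Both simulated players follow the same greedy policy as the body
+        below: always play the legal card with the most points,
+
+            move = max(legal_moves, key=lambda c: c.points)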
+ """ + + playout_state = self._shallow_copy_state(state_info) + p0_h = set(p0_hand) + p1_h = set(p1_hand) + + num_played = len(playout_state['played_cards']) + tricks_to_play = (32 - num_played - len(playout_state['trick_moves'])) // 2 + + if len(playout_state['trick_moves']) == 1: + current_hand = p1_h if playout_state['player_id'] == (1 - original_player_id) else p0_h + legal_moves = self._get_legal_in_playout(current_hand, playout_state['trick_moves'], playout_state['trump_suit']) + if legal_moves: + move = max(legal_moves, key=lambda c: c.points) + current_hand.remove(move) + playout_state['trick_moves'].append((playout_state['player_id'], move)) + self._resolve_trick(playout_state, original_player_id) + + for _ in range(tricks_to_play): + leader_hand = p0_h if playout_state['player_id'] == original_player_id else p1_h + follower_hand = p1_h if playout_state['player_id'] == original_player_id else p0_h + + leader_legal = self._get_legal_in_playout(leader_hand, [], playout_state['trump_suit']) + if not leader_legal: break + leader_move = max(leader_legal, key=lambda c: c.points) + leader_hand.remove(leader_move) + playout_state['trick_moves'].append((playout_state['player_id'], leader_move)) + + follower_legal = self._get_legal_in_playout(follower_hand, playout_state['trick_moves'], playout_state['trump_suit']) + if not follower_legal: break + follower_move = max(follower_legal, key=lambda c: c.points) + follower_hand.remove(follower_move) + playout_state['trick_moves'].append((1 - playout_state['player_id'], follower_move)) + + self._resolve_trick(playout_state, original_player_id) + + return playout_state['my_score'] - playout_state['opponent_score'] + + def _run_alpha_beta(self, state_info, p0_hand, p1_hand, depth, alpha, beta, is_maximizing, original_player_id): + """ + Runs the alpha-beta pruning on the given state. 
+ """ + + if depth == 0 or (len(p0_hand) == 0 and len(p1_hand) == 0 and not state_info['trick_moves']): + return self._evaluate_state(state_info, p0_hand, p1_hand, original_player_id) + + current_player_id = state_info['player_id'] + current_hand = p0_hand if current_player_id == original_player_id else p1_hand + legal_moves = self._get_legal_in_playout(current_hand, state_info['trick_moves'], state_info['trump_suit']) + if not legal_moves: + return self._evaluate_state(state_info, p0_hand, p1_hand, original_player_id) + + trump_suit = state_info['trump_suit'] + non_trump_hand = [c for c in current_hand if not _is_trump(c, trump_suit)] + suit_counts = Counter(c.suit for c in non_trump_hand) + if self.use_move_ordering: + sorted_moves = sorted(legal_moves, key=lambda card: self._advanced_heuristic_move_score(card, current_hand, state_info, suit_counts), reverse=True) + else: + sorted_moves = legal_moves + + if is_maximizing: + max_eval = -999999 + for card in sorted_moves: + child_state = self._get_next_state(state_info, p0_hand, p1_hand, card, current_player_id, original_player_id) + eval_score = self._run_alpha_beta(child_state['state'], child_state['p0_hand'], child_state['p1_hand'], depth - 1, alpha, beta, not is_maximizing, original_player_id) + max_eval = max(max_eval, eval_score) + alpha = max(alpha, eval_score) + if beta <= alpha: break + return max_eval + else: + min_eval = 999999 + for card in sorted_moves: + child_state = self._get_next_state(state_info, p0_hand, p1_hand, card, current_player_id, original_player_id) + eval_score = self._run_alpha_beta(child_state['state'], child_state['p0_hand'], child_state['p1_hand'], depth - 1, alpha, beta, not is_maximizing, original_player_id) + min_eval = min(min_eval, eval_score) + beta = min(beta, eval_score) + if beta <= alpha: break + return min_eval + + def _advanced_heuristic_move_score(self, card, hand, state_info, suit_counts=None): + """ + Create a score for move ordering in alpha-beta search. + """ + + trump_suit = state_info['trump_suit'] + is_creating_void = False + if not _is_trump(card, trump_suit): + if suit_counts is None: + count_in_suit = sum(1 for c in hand if c.suit == card.suit and not _is_trump(c, trump_suit)) + else: + count_in_suit = suit_counts.get(card.suit, 0) + if count_in_suit == 1: is_creating_void = True + + if state_info['trick_moves']: + led_card = state_info['trick_moves'][0][1] + led_suit = led_card.suit + led_strength = self._cached_get_card_strength(led_card, trump_suit, led_suit) + my_strength = self._cached_get_card_strength(card, trump_suit, led_suit) + + if my_strength > led_strength: + trick_points = led_card.points + card.points + if trick_points >= 10: return 500 + trick_points + else: return 400 - my_strength + else: + score = 200 + (11 - card.points) + if is_creating_void: score += 100 + return score + else: + if card.rank == 'J' and trump_suit != 'G': return 600 + self._cached_get_card_strength(card, trump_suit, card.suit) + if card.suit == trump_suit and trump_suit != 'G': return 500 + self._cached_get_card_strength(card, trump_suit, card.suit) + if card.points >= 10: return 400 + card.points + score = 200 - card.points + if is_creating_void: score += 100 + return score + + def _evaluate_state(self, state_info, p0_hand, p1_hand, original_player_id): + """ + Evaluate the current game state. 
+ """ + + score_diff = state_info['my_score'] - state_info['opponent_score'] + p0_hand_value = sum(c.points for c in p0_hand) + p1_hand_value = sum(c.points for c in p1_hand) + hand_diff = p0_hand_value - p1_hand_value + if state_info['player_id'] == original_player_id: + score_diff += hand_diff // 10 + else: + score_diff -= hand_diff // 10 + return score_diff + + def _get_next_state(self, state_info, p0_hand, p1_hand, played_card, player_id, original_player_id): + """ + Gives the next state after a card is played. + """ + + next_state = self._shallow_copy_state(state_info) + p0_h = list(p0_hand) + p1_h = list(p1_hand) + if player_id == original_player_id: + if played_card in p0_h: p0_h.remove(played_card) + else: + if played_card in p1_h: p1_h.remove(played_card) + next_state['trick_moves'].append((player_id, played_card)) + if len(next_state['trick_moves']) == 2: + self._resolve_trick(next_state, original_player_id) + else: + next_state['player_id'] = 1 - player_id + return {'state': next_state, 'p0_hand': p0_h, 'p1_hand': p1_h} + + def _resolve_trick(self, state_info, original_player_id): + """ + Resolves a trick and updates the game state. + """ + + p0_move = state_info['trick_moves'][0] + p1_move = state_info['trick_moves'][1] + led_suit = p0_move[1].suit + p0_strength = self._cached_get_card_strength(p0_move[1], state_info['trump_suit'], led_suit) + p1_strength = self._cached_get_card_strength(p1_move[1], state_info['trump_suit'], led_suit) + winner_id = p0_move[0] if p0_strength > p1_strength else p1_move[0] + trick_points = p0_move[1].points + p1_move[1].points + if winner_id == original_player_id: + state_info['my_score'] += trick_points + else: + state_info['opponent_score'] += trick_points + state_info['played_cards'].add(p0_move[1]) + state_info['played_cards'].add(p1_move[1]) + state_info['trick_moves'] = [] + state_info['player_id'] = winner_id + + def _get_legal_in_playout(self, hand, trick_moves, trump_suit): + """ + Returns the legal cards in a simulation. + """ + + if not hand: return [] + if not trick_moves: return list(hand) + led_card = trick_moves[0][1] + trumps_in_hand = {card for card in hand if _is_trump(card, trump_suit)} + if _is_trump(led_card, trump_suit): + if trumps_in_hand: return list(trumps_in_hand) + else: + led_suit = led_card.suit + suit_in_hand = {card for card in hand if card.suit == led_suit and not _is_trump(card, trump_suit)} + if suit_in_hand: return list(suit_in_hand) + return list(hand) + + def eval_step(self, state): + """ + Selects a legal action for evaluation. 
+ """ + + action = self.step(state) + return action, {} \ No newline at end of file diff --git a/rlcard/agents/bauernskat/sac_agent/__init__.py b/rlcard/agents/bauernskat/sac_agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rlcard/agents/bauernskat/sac_agent/agent.py b/rlcard/agents/bauernskat/sac_agent/agent.py new file mode 100644 index 000000000..6478b13b9 --- /dev/null +++ b/rlcard/agents/bauernskat/sac_agent/agent.py @@ -0,0 +1,310 @@ +''' + File name: rlcard/games/bauernskat/sac_agent/agent.py + Author: Oliver Czerwinski + Date created: 11/10/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +import random +from collections import Counter +from typing import Dict, Tuple, List, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch.distributions import Categorical +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts + +from rlcard.envs.env import Env +from rlcard.games.bauernskat.action_event import ActionEvent, DeclareTrumpAction +from rlcard.games.bauernskat.card import BauernskatCard +from rlcard.agents.bauernskat import rule_agents as bauernskat_rule_agents + +from rlcard.agents.bauernskat.sac_agent.model import BauernskatNet +from rlcard.agents.bauernskat.sac_agent.config import BauernskatNetConfig, TrainerConfig + +class SACEstimator: + """ + Soft Actor-Critic estimator with Actor and Twin Critics. + """ + def __init__(self, net_config: BauernskatNetConfig, train_config: TrainerConfig, device: torch.device): + self.device = device + self.gamma = train_config.gamma + self.tau = train_config.tau + + self.net = BauernskatNet(net_config).to(device) + self.target_net = BauernskatNet(net_config).to(device) + self.target_net.load_state_dict(self.net.state_dict()) + self.target_net.eval() + + self.optimizer = torch.optim.AdamW( + self.net.parameters(), + lr=train_config.critic_lr, + weight_decay=train_config.weight_decay + ) + + self.use_scheduler = train_config.use_lr_scheduler + if self.use_scheduler: + t0_steps = max(1, int(train_config.cosine_T0 / train_config.batch_size)) + + self.scheduler = CosineAnnealingWarmRestarts( + self.optimizer, + T_0=t0_steps, + T_mult=train_config.cosine_T_mult, + eta_min=train_config.cosine_eta_min + ) + + self.learn_alpha = train_config.learn_alpha + self.target_entropy = -np.log(1.0 / 8.0) * train_config.target_entropy_ratio + + if self.learn_alpha: + self.log_alpha = torch.zeros(1, requires_grad=True, device=device) + self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=train_config.alpha_lr) + self.alpha = self.log_alpha.exp() + else: + self.alpha = torch.tensor(train_config.initial_alpha).to(device) + + def train_step(self, batch: dict, clip_norm: float) -> Tuple[float, float, float, float, float]: + """ + Performs a training step using a batch of transitions. 
+ """ + + # Batch Tensors + states = batch.get('observation').to(self.device) + actions = batch.get('action').long().to(self.device) + rewards = batch.get(('next', 'reward')).to(self.device) + dones = batch.get(('next', 'done')).float().to(self.device) + next_states = batch.get(('next', 'observation')).to(self.device) + masks = batch.get('legal_actions_mask').bool().to(self.device) + next_masks = batch.get(('next', 'legal_actions_mask')).bool().to(self.device) + + logits, q1_all, q2_all = self.net.evaluate_all_actions(states) + + # Q-Values + q1_pred = q1_all.gather(1, actions) + q2_pred = q2_all.gather(1, actions) + + # Target Q-Values + with torch.no_grad(): + next_logits, next_q1, next_q2 = self.target_net.evaluate_all_actions(next_states) + + next_logits[~next_masks] = -1e8 + next_probs = F.softmax(next_logits, dim=-1) + next_log_probs = F.log_softmax(next_logits, dim=-1) + + min_next_q = torch.min(next_q1, next_q2) + + target_v = torch.sum(next_probs * (min_next_q - self.alpha * next_log_probs), dim=-1, keepdim=True) + target_q = rewards + (1 - dones) * self.gamma * target_v + + critic_loss = F.mse_loss(q1_pred, target_q) + F.mse_loss(q2_pred, target_q) + + q1_detach = q1_all.detach() + q2_detach = q2_all.detach() + min_q = torch.min(q1_detach, q2_detach) + + logits[~masks] = -1e8 + probs = F.softmax(logits, dim=-1) + log_probs = F.log_softmax(logits, dim=-1) + + actor_loss = torch.sum(probs * (self.alpha * log_probs - min_q), dim=-1).mean() + + alpha_loss = 0.0 + curr_alpha = self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha + + # Update Entropy Temperature + if self.learn_alpha: + with torch.no_grad(): + entropy = -torch.sum(probs * log_probs, dim=-1).mean() + + alpha_loss = -(self.log_alpha * (entropy - self.target_entropy).detach()) + + self.alpha_optim.zero_grad() + alpha_loss.backward() + self.alpha_optim.step() + + self.alpha = self.log_alpha.exp() + curr_alpha = self.alpha.item() + + # Update Netoworks + total_loss = critic_loss + actor_loss + self.optimizer.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(self.net.parameters(), clip_norm) + self.optimizer.step() + + current_lr = self.optimizer.param_groups[0]['lr'] + if self.use_scheduler: + self.scheduler.step() + current_lr = self.scheduler.get_last_lr()[0] + + return critic_loss.item(), actor_loss.item(), curr_alpha, q1_pred.mean().item(), current_lr + + def update_target_net(self): + """ + Soft-update for the target network. + """ + + with torch.no_grad(): + for param, target_param in zip(self.net.parameters(), self.target_net.parameters()): + target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) + + +class AgentSAC_Actor: + """ + An agent that uses a Soft Actor-Critic network for decision making. + """ + + def __init__(self, net: BauernskatNet, device: str = 'cpu', use_teacher: bool = False): + """ + Initializes AgentSAC_Actor. + """ + + self.net = net + self.device = device + + self.teacher = bauernskat_rule_agents.BauernskatLookaheadRuleAgent() if use_teacher else None + + @staticmethod + def _map_action_to_sac_idx(action_id: int) -> int: + """ + Maps action_id to SAC index. + """ + + if action_id < 5: return 32 + action_id + else: return action_id - 5 + + @staticmethod + def _map_sac_idx_to_action(idx: int) -> int: + """ + Maps SAC index back to action_id. 
+ """ + + if idx >= 32: return idx - 32 + else: return idx + 5 + + def _get_rule_based_trump_action(self, state: dict) -> Optional[int]: + """ + Heuristic for trump selection: + - If 2 or more Jacks: Declare Grand ('G') + - Else: Suit with most cards. + - Otherwise: Suit with highest rank card. + """ + + legal_action_ids = list(state['legal_actions'].keys()) + raw_info = state.get('raw_state_info') + if not raw_info: return None + + my_cards = raw_info.get('my_cards', []) + if not my_cards: return random.choice(legal_action_ids) + + # Jacks + num_jacks = sum(1 for card in my_cards if card.rank == 'J') + if num_jacks >= 2: + for action_id in legal_action_ids: + action = ActionEvent.from_action_id(action_id) + if isinstance(action, DeclareTrumpAction) and action.trump_suit == 'G': + return action_id + + # Suit counts + suit_counts = Counter(card.suit for card in my_cards) + if not suit_counts: + return random.choice(legal_action_ids) + + max_count = max(suit_counts.values()) + best_suits = [suit for suit, count in suit_counts.items() if count == max_count] + + # Highest rank suit card + if len(best_suits) == 1: + best_suit_choice = best_suits[0] + else: + best_rank_val = -1 + + rank_order = BauernskatCard.ranks + best_suit_choice = best_suits[0] + + for suit in best_suits: + cards_of_suit = [card for card in my_cards if card.suit == suit] + if cards_of_suit: + max_rank_in_suit = max(rank_order.index(c.rank) for c in cards_of_suit) + if max_rank_in_suit > best_rank_val: + best_rank_val = max_rank_in_suit + best_suit_choice = suit + + for action_id in legal_action_ids: + action = ActionEvent.from_action_id(action_id) + if isinstance(action, DeclareTrumpAction) and action.trump_suit == best_suit_choice: + return action_id + + return random.choice(legal_action_ids) + + def step(self, state: dict, env: Env, trump_rule_prob: float = 0.0, teacher_epsilon: float = 0.0) -> Tuple[int, List[int]]: + """ + Chooses an action based on SAC policy. + """ + + legal_actions = list(state['legal_actions'].keys()) + if not legal_actions: + return -1, [] + + r = random.random() + + # Teacher Forcing + if self.teacher is not None and r < teacher_epsilon: + action = self.teacher.step(state) + return action, legal_actions + + # 2. Rule-Based Trump Selection + if trump_rule_prob > 0.0: + raw_info = state.get('raw_state_info') + if raw_info and raw_info.get('round_phase') == 'declare_trump': + if random.random() < trump_rule_prob: + action = self._get_rule_based_trump_action(state) + if action is not None: + return action, legal_actions + + # 3. SAC Policy + legal_indices = [self._map_action_to_sac_idx(a) for a in legal_actions] + + mask = torch.zeros(38, dtype=torch.bool) + mask[legal_indices] = True + + with torch.no_grad(): + state_batch = {k: torch.from_numpy(np.array(v)).unsqueeze(0).to(self.device) + for k, v in state['obs'].items()} + + logits, _, _ = self.net.evaluate_all_actions(state_batch) + logits = logits.squeeze(0).cpu() + + logits[~mask] = -float('inf') + + probs = F.softmax(logits, dim=-1) + dist = Categorical(probs) + sac_idx = dist.sample().item() + + return self._map_sac_idx_to_action(sac_idx), legal_actions + + def eval_step(self, state: dict, env: Env) -> Tuple[int, Dict]: + """ + Chooses the best action without exploration. 
+ """ + + legal_actions = list(state['legal_actions'].keys()) + legal_indices = [self._map_action_to_sac_idx(a) for a in legal_actions] + + mask = torch.zeros(38, dtype=torch.bool) + mask[legal_indices] = True + + with torch.no_grad(): + state_batch = {k: torch.from_numpy(np.array(v)).unsqueeze(0).to(self.device) + for k, v in state['obs'].items()} + + logits, _, _ = self.net.evaluate_all_actions(state_batch) + logits = logits.squeeze(0).cpu() + + logits[~mask] = -float('inf') + + sac_idx = torch.argmax(logits).item() + + return self._map_sac_idx_to_action(sac_idx), {} \ No newline at end of file diff --git a/rlcard/agents/bauernskat/sac_agent/config.py b/rlcard/agents/bauernskat/sac_agent/config.py new file mode 100644 index 000000000..f2ae2bcb7 --- /dev/null +++ b/rlcard/agents/bauernskat/sac_agent/config.py @@ -0,0 +1,155 @@ +''' + File name: rlcard/games/bauernskat/sac_agent/config.py + Author: Oliver Czerwinski + Date created: 11/10/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +from dataclasses import dataclass, field +from typing import Tuple, Literal +import torch + +# Bauernskat specific constants +MAX_PLAYER_CARDS = 16 +MAX_TRICK_SIZE = 2 +MAX_CEMETERY_SIZE = 32 + +# Model architecture +@dataclass +class BauernskatNetConfig: + """ + Configuration for the BauernskatNet model. + """ + + card_embedding_dim: int = 32 + branch_output_dim: int = 96 + + pool_type: Literal['mean', 'sum'] = 'mean' + + mlp_hidden_dims: Tuple[int, ...] = (64, 64) + indicator_mlp_dims: Tuple[int, ...] = (64, 64) + layout_processor_hidden_dim: int = 128 + mask_processor_hidden_dims: Tuple[int, ...] = (64,) + + num_lstm_layers: int = 2 + lstm_hidden_dim: int = 96 + use_bidirectional: bool = True + use_attention: bool = True + attn_heads: int = 4 + lstm_fc_dims: Tuple[int, ...] = (96,) + + context_vector_dim: int = 11 + indicator_vector_dim: int = 8 + action_history_frame_size: int = 49 + + head_hidden_dims: Tuple[int, ...] = (512, 256) + head_dropout: float = 0.0 + +# Training configuration +@dataclass +class TrainerConfig: + """ + Configuration for the DMCTrainer. 
+ """ + + xpid: str = 'sac_agent_bauernskat_v1' + savedir: str = 'experiments/sac_agent_result' + load_model: bool = True + save_every_frames: int = 4_096_000 + seed: int = 21000 + + # Logging + log_to_tensorboard: bool = True + log_p0_p1_payoffs: bool = True + log_every_frames: int = 4_096 + log_interval_seconds: float = 5.0 + + # Pipeline & Threading + cuda: str = '0' + training_device: str = "0" + num_actors: int = 10 + num_threads: int = 1 + actor_queue_size_multiplier: int = 64 + actor_game_batch_size: int = 1 + process_join_timeout: float = 5.0 + sample_queue_put_timeout: float = 5.0 + + # Training Hyperparameters + batch_size: int = 1024 + gamma: float = 0.99 + tau: float = 0.005 + n_step_returns: int = 3 + actor_lr: float = 3e-4 + critic_lr: float = 1.5e-4 + alpha_lr: float = 3e-4 + gradient_clip_norm: float = 1.0 + weight_decay: float = 1e-3 + + use_lr_scheduler: bool = True + cosine_T0: int = 5_120_000 + cosine_T_mult: int = 2 + cosine_eta_min: float = 3e-6 + + # Entropy parameters + initial_alpha: float = 1.0 + learn_alpha: bool = True + target_entropy_ratio: float = 0.98 + + # Replay Buffer + replay_buffer_size: int = 204_800 + min_buffer_size_to_learn: int = 8_192 + + # Prioritized Experience Replay + per_alpha: float = 0.6 + per_beta: float = 0.4 + + # Reward Function + reward_type: Literal['game_score', 'binary', 'hybrid'] = 'hybrid' + + # Parameters for 'hybrid' reward function + max_reward_abs: float = 480.0 + reward_shaping_steepness: float = 0.009 + reward_shaping_threshold: int = 18 + reward_shaping_score_weight: float = 0.5 + reward_shaping_win_bonus: float = 1.0 + + # Rule-Based Trump Selection + use_rule_based_trump_decay: bool = False + trump_start: float = 1.0 + trump_end: float = 0.0 + trump_decay_frames: int = 1_024_000_000 + + # Teacher Forcing + use_teacher_forcing: bool = False + teacher_start: float = 1.0 + teacher_end: float = 0.0 + teacher_decay_frames: int = 64_000_000 + + # Environment and Evaluation + env: str = 'bauernskat' + information_level: Literal['normal', 'show_self', 'perfect'] = 'normal' + total_frames: int = 1_024_000_000 + eval_every: int = 4_096_000 + num_eval_games: int = 512 + + # Model Configuration + model_config: BauernskatNetConfig = field(default_factory=BauernskatNetConfig) + device: torch.device = field(init=False) + + def __post_init__(self): + """ + Sets the device and validate some hyperparameters. 
+ """ + + if self.training_device != "cpu" and torch.cuda.is_available(): + self.device = torch.device(f"cuda:{self.training_device}") + else: + self.device = torch.device("cpu") + + if self.min_buffer_size_to_learn > self.replay_buffer_size: + raise ValueError("min_buffer_size_to_learn cannot be larger than replay_buffer_size") + if self.batch_size > self.min_buffer_size_to_learn: + raise ValueError("batch_size cannot be larger than min_buffer_size_to_learn") + if self.num_actors <= 0: + raise ValueError("num_actors must be a positive integer") \ No newline at end of file diff --git a/rlcard/agents/bauernskat/sac_agent/model.py b/rlcard/agents/bauernskat/sac_agent/model.py new file mode 100644 index 000000000..835e8130e --- /dev/null +++ b/rlcard/agents/bauernskat/sac_agent/model.py @@ -0,0 +1,289 @@ +''' + File name: rlcard/games/bauernskat/sac_agent/model.py + Author: Oliver Czerwinski + Date created: 11/10/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +import torch +import torch.nn as nn +from typing import Dict, Tuple +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + +from rlcard.agents.bauernskat.sac_agent.config import BauernskatNetConfig + +class ResidualBlock(nn.Module): + """ + Basic residual block. + """ + + def __init__(self, dim: int): + """ + Initializes ResidualBlock. + """ + + super().__init__() + self.layers = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * 4), + nn.GELU(), + nn.Linear(dim * 4, dim) + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Outputs the result of the residual block. + """ + + return x + self.layers(x) + +class LayoutProcessor(nn.Module): + """ + Processes a (8, 2) layout tensor. + """ + + def __init__(self, shared_card_embedding: nn.Embedding, output_dim: int, hidden_dim: int): + """ + Initializes LayoutProcessor. + """ + + super().__init__() + self.embedding = shared_card_embedding + embedding_dim = self.embedding.embedding_dim + + input_size = 8 * 2 * embedding_dim + self.mlp = nn.Sequential( + nn.Linear(input_size, hidden_dim), + nn.GELU(), + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, output_dim) + ) + + def forward(self, layout_tensor: torch.Tensor) -> torch.Tensor: + """ + Outputs an embedding for the layout tensor. + """ + + embedded = self.embedding(layout_tensor) + flattened = embedded.view(embedded.shape[0], -1) + return self.mlp(flattened) + +class CardSetProcessor(nn.Module): + """ + Processes a flexible sized set of cards. + """ + + def __init__(self, shared_card_embedding: nn.Embedding, output_dim: int, pool_type: str = 'mean'): + + + super().__init__() + self.embedding = shared_card_embedding + self.pool_type = pool_type + self.padding_idx = shared_card_embedding.padding_idx + + self.mlp = nn.Sequential( + nn.Linear(self.embedding.embedding_dim, output_dim), + nn.GELU(), + nn.LayerNorm(output_dim) + ) + def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Outputs an embedding for the set of cards. 
+ """ + + if ids.shape[1] == 0: + return torch.zeros(ids.shape[0], self.mlp[0].out_features, device=ids.device) + + safe_ids = ids.clone() + if self.padding_idx is not None: + safe_ids[ids == -1] = self.padding_idx + + embedded = self.embedding(safe_ids) + if self.pool_type == 'mean': + num_cards = mask.sum(dim=1, keepdim=True).clamp(min=1) + pooled = embedded.sum(dim=1) / num_cards + elif self.pool_type == 'sum': + pooled = embedded.sum(dim=1) + + return self.mlp(pooled) + +class BauernskatNet(nn.Module): + """ + SAC Network for Bauernskat. + """ + + def __init__(self, config: BauernskatNetConfig): + """ + Initializes BauernskatNet. + """ + + super().__init__() + self.config = config + + self.card_embedding = nn.Embedding(33, config.card_embedding_dim, padding_idx=32) + self.trump_action_embedding = nn.Embedding(6, config.card_embedding_dim, padding_idx=5) + + card_set_args = (self.card_embedding, config.branch_output_dim, config.pool_type) + layout_proc_args = (self.card_embedding, config.branch_output_dim, config.layout_processor_hidden_dim) + + self.my_layout_processor = LayoutProcessor(*layout_proc_args) + self.opponent_layout_processor = LayoutProcessor(*layout_proc_args) + self.unaccounted_mask_processor = self._build_mlp(32, list(config.mask_processor_hidden_dims), config.branch_output_dim) + self.trick_processor = CardSetProcessor(*card_set_args) + self.cemetery_processor = CardSetProcessor(*card_set_args) + + self.indicator_processor = self._build_mlp(config.indicator_vector_dim, list(config.indicator_mlp_dims), config.branch_output_dim) + self.context_processor = self._build_mlp(config.context_vector_dim, list(config.mlp_hidden_dims), config.branch_output_dim) + + lstm_out_dim = config.lstm_hidden_dim * (2 if config.use_bidirectional else 1) + self.lstm = nn.LSTM(config.action_history_frame_size, config.lstm_hidden_dim, config.num_lstm_layers, + bidirectional=config.use_bidirectional, batch_first=True) + + self.attn = nn.MultiheadAttention(lstm_out_dim, config.attn_heads, batch_first=True) if config.use_attention else None + self.history_processor = self._build_mlp(lstm_out_dim, list(config.lstm_fc_dims), config.branch_output_dim) + + self.action_card_processor = CardSetProcessor(*card_set_args) + self.trump_action_processor = nn.Sequential( + nn.Linear(config.card_embedding_dim, config.branch_output_dim), + nn.GELU(), + nn.LayerNorm(config.branch_output_dim) + ) + + # SAC Heads + concat_dim = config.branch_output_dim * 8 + + head_input_dim = concat_dim + config.branch_output_dim + + def build_head(): + layers = [] + dims = [head_input_dim] + list(config.head_hidden_dims) + for i in range(len(config.head_hidden_dims)): + layers.extend([ + nn.Linear(dims[i], dims[i+1]), + ResidualBlock(dims[i+1]), + nn.GELU() + ]) + layers.append(nn.Linear(dims[-1], 1)) + return nn.Sequential(*layers) + + self.actor_head = build_head() + self.critic1_head = build_head() + self.critic2_head = build_head() + + self.register_buffer('all_actions_indices', torch.arange(38, dtype=torch.long)) + + @staticmethod + def _build_mlp(input_dim, hidden_dims, output_dim): + """ + Creates an MLP with specific dimensions. + """ + + layers = [] + curr = input_dim + for h in hidden_dims: + layers.extend([nn.Linear(curr, h), nn.GELU()]) + curr = h + layers.append(nn.Linear(curr, output_dim)) + return nn.Sequential(*layers) + + def _forward_history(self, x: torch.Tensor) -> torch.Tensor: + """ + Processes the action history using LSTM and attention. 
+ """ + + B, _, _ = x.shape + lengths = torch.sum(x.abs().sum(dim=-1) > 0, dim=-1) + full_batch_summary = torch.zeros(B, self.config.branch_output_dim, device=x.device) + + non_empty_mask = lengths > 0 + if not non_empty_mask.any(): + return full_batch_summary + + non_empty_x = x[non_empty_mask] + non_empty_lengths = lengths[non_empty_mask] + non_empty_indices = non_empty_mask.nonzero(as_tuple=True)[0] + + sorted_lengths, sorted_indices = torch.sort(non_empty_lengths, descending=True) + sorted_x = non_empty_x.index_select(0, sorted_indices) + + packed_input = pack_padded_sequence(sorted_x, sorted_lengths.cpu(), batch_first=True, enforce_sorted=True) + packed_output, _ = self.lstm(packed_input) + lstm_out, _ = pad_packed_sequence(packed_output, batch_first=True, total_length=sorted_x.size(1)) + + _, unsorted_indices = torch.sort(sorted_indices) + lstm_out = lstm_out.index_select(0, unsorted_indices) + + if self.attn: + b_non, s_non, _ = lstm_out.shape + indices = torch.arange(s_non, device=x.device).expand(b_non, -1) + attn_mask = indices >= non_empty_lengths.unsqueeze(1) + last_seq_idxs = (non_empty_lengths - 1).clamp(min=0) + query = lstm_out[torch.arange(b_non), last_seq_idxs, :].unsqueeze(1) + attn_out, _ = self.attn(query=query, key=lstm_out, value=lstm_out, key_padding_mask=attn_mask) + summary = attn_out.squeeze(1) + else: + b_non = lstm_out.shape[0] + last_seq_idxs = (non_empty_lengths - 1).clamp(min=0) + summary = lstm_out[torch.arange(b_non), last_seq_idxs, :] + + processed = self.history_processor(summary) + full_batch_summary.index_add_(0, non_empty_indices, processed) + + return full_batch_summary + + def encode_state(self, state_obs: Dict[str, torch.Tensor]) -> torch.Tensor: + """ + Encodes the state observation into a fixed-size vector. + """ + + my_layout = self.my_layout_processor(state_obs['my_layout_tensor']) + opp_layout = self.opponent_layout_processor(state_obs['opponent_layout_tensor']) + mask_vec = self.unaccounted_mask_processor(state_obs['unaccounted_cards_mask']) + trick = self.trick_processor(state_obs['trick_card_ids'], state_obs['trick_card_ids'] != -1) + cemetery = self.cemetery_processor(state_obs['cemetery_card_ids'], state_obs['cemetery_card_ids'] != -1) + + my_ind = self.indicator_processor(state_obs['my_hidden_indicators']) + opp_ind = self.indicator_processor(state_obs['opponent_hidden_indicators']) + indicator = my_ind + opp_ind + + ctx = self.context_processor(state_obs['context']) + hist = self._forward_history(state_obs['action_history']) + + return torch.cat([my_layout, opp_layout, mask_vec, trick, cemetery, indicator, ctx, hist], dim=-1) + + def _process_all_actions(self): + """ + Processes all 38 actions into embeddings. + """ + + # Card Actions + card_indices = torch.arange(32, device=self.all_actions_indices.device).unsqueeze(1) + card_vecs = self.action_card_processor(card_indices, torch.ones_like(card_indices, dtype=torch.bool)) + + # Trump Actions + trump_indices = torch.arange(6, device=self.all_actions_indices.device).unsqueeze(1) + trump_embs = self.trump_action_embedding(trump_indices) + trump_vecs = self.trump_action_processor(trump_embs.squeeze(1)) + + return torch.cat([card_vecs, trump_vecs], dim=0) + + def evaluate_all_actions(self, state_obs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Evaluates all possible actions for the given state. 
+ """ + + batch_size = state_obs['context'].shape[0] + state_vec = self.encode_state(state_obs) + + action_vecs = self._process_all_actions() + + state_expanded = state_vec.unsqueeze(1).expand(-1, 38, -1) + action_expanded = action_vecs.unsqueeze(0).expand(batch_size, -1, -1) + + fused = torch.cat([state_expanded, action_expanded], dim=-1) + + logits = self.actor_head(fused).squeeze(-1) + q1 = self.critic1_head(fused).squeeze(-1) + q2 = self.critic2_head(fused).squeeze(-1) + + return logits, q1, q2 \ No newline at end of file diff --git a/rlcard/agents/bauernskat/sac_agent/reward.py b/rlcard/agents/bauernskat/sac_agent/reward.py new file mode 100644 index 000000000..dd03c4da6 --- /dev/null +++ b/rlcard/agents/bauernskat/sac_agent/reward.py @@ -0,0 +1,65 @@ +''' + File name: rlcard/games/bauernskat/sac_agent/reward.py + Author: Oliver Czerwinski + Date created: 11/10/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +import numpy as np + +def _custom_centered_tanh(final_score: float, steepness: float, win_loss_threshold: int) -> float: + """ + Centered tanh function to compress score magnitudes. + """ + + if final_score >= win_loss_threshold: + adjusted_magnitude = float(final_score - win_loss_threshold) + return np.tanh(adjusted_magnitude * steepness) + elif final_score <= -win_loss_threshold: + adjusted_magnitude = float(abs(final_score) - win_loss_threshold) + return -np.tanh(adjusted_magnitude * steepness) + else: + return 0.0 + +def calculate_game_score_reward(final_score: float) -> float: + """ + Returns the raw game score as reward. + """ + + return float(final_score) + +def calculate_binary_reward(final_score: float) -> float: + """ + Returns +1.0 for win or -1.0 for loss as reward. + """ + + return float(np.sign(final_score)) + +def calculate_hybrid_reward(my_final_pips: int, opponent_final_pips: int, final_score: float, steepness: float = 0.009, threshold: int = 18, score_weight: float = 0.5, win_bonus_magnitude: float = 1.0) -> float: + """ + Calculates a hybrid reward based on game outcome, pip difference and score magnitude. 
+ """ + + # Sign of the outcome + outcome_sign = 0.0 + if final_score >= threshold: + outcome_sign = 1.0 + elif final_score <= -threshold: + outcome_sign = -1.0 + + # Safety for draw + if outcome_sign == 0.0: + return 0.0 + + # Pip difference + r_base = float(my_final_pips - opponent_final_pips) + + # Score multiplier + compressed_score = _custom_centered_tanh(final_score, steepness=steepness, win_loss_threshold=threshold) + m_score = 1.0 + score_weight * abs(compressed_score) + + total_magnitude = win_bonus_magnitude + abs(r_base * m_score) + final_reward = outcome_sign * total_magnitude + + return final_reward \ No newline at end of file diff --git a/rlcard/agents/bauernskat/sac_agent/trainer.py b/rlcard/agents/bauernskat/sac_agent/trainer.py new file mode 100644 index 000000000..d4a9e7bb3 --- /dev/null +++ b/rlcard/agents/bauernskat/sac_agent/trainer.py @@ -0,0 +1,691 @@ +''' + File name: rlcard/games/bauernskat/sac_agent/trainer.py + Author: Oliver Czerwinski + Date created: 11/10/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +import os +import pprint +import threading +import time +import datetime +import traceback +import copy +import csv +import json +import logging +import dataclasses +import random +import argparse +import queue +from typing import Dict, Any, Literal, get_origin, get_args +from multiprocessing.synchronize import Lock as LockType +from multiprocessing.queues import Queue as QueueType +from multiprocessing.sharedctypes import Synchronized as SynchronizedType + +import numpy as np +import torch +from torch import multiprocessing as mp +from torch import nn +from torch.utils.tensorboard import SummaryWriter + +from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage +from torchrl.data.replay_buffers.samplers import PrioritizedSampler +from tensordict import TensorDict + +import rlcard +from rlcard.agents.bauernskat.sac_agent.config import TrainerConfig +from rlcard.agents.bauernskat.sac_agent.model import BauernskatNet +from rlcard.agents.bauernskat.sac_agent.agent import SACEstimator, AgentSAC_Actor +from rlcard.agents.bauernskat.sac_agent.utils import ObsPreprocessor, setup_logging, TrainingLogger, AgentEvaluator +from rlcard.agents.bauernskat.sac_agent.reward import calculate_hybrid_reward, calculate_binary_reward, calculate_game_score_reward + + +log = logging.getLogger('agent_sac_trainer') + +def format_time(seconds: float) -> str: + """ + Formats seconds into a HH:MM:SS. + """ + + return str(datetime.timedelta(seconds=int(seconds))) + +def gather_metadata(config: TrainerConfig) -> Dict: + """ + Gathers metadata about the training run. + """ + + date_start = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + + slurm_data = {k.replace('SLURM_', '').lower(): v for k, v in os.environ.items() if k.startswith('SLURM')} or None + env_whitelist = ('USER', 'HOSTNAME') + safe_env = {k: v for k, v in os.environ.items() if k.startswith('SLURM') or k in env_whitelist} + + def custom_dict_factory(data): + """ + Handles non-serializable types in dataclasses. + """ + + return {k: str(v) if isinstance(v, torch.device) else v for k, v in data} + + config_dict = dataclasses.asdict(config, dict_factory=custom_dict_factory) + + return dict(date_start=date_start, date_end=None, successful=False, + slurm=slurm_data, env=safe_env, config=config_dict) + + +class FileWriter: + """ + Handles logging to files and saving metadata. + """ + + def __init__(self, xpid: str, rootdir: str, config: TrainerConfig): + """ + Initializes FileWriter. 
+ """ + + self.xpid = xpid + self._tick = 0 + self.metadata = gather_metadata(config) + self.metadata['xpid'] = self.xpid + + self._logger = logging.getLogger(f'filewriter/{self.xpid}') + self._logger.setLevel(logging.INFO) + self._logger.propagate = False + + self.basepath = os.path.join(os.path.expandvars(os.path.expanduser(rootdir)), self.xpid) + os.makedirs(self.basepath, exist_ok=True) + + self.paths = { + 'msg': f'{self.basepath}/out.log', 'logs': f'{self.basepath}/logs.csv', + 'fields': f'{self.basepath}/fields.csv', 'meta': f'{self.basepath}/meta.json'} + + self._save_metadata() + + fhandle = logging.FileHandler(self.paths['msg']) + fhandle.setFormatter(logging.Formatter('%(message)s')) + self._logger.addHandler(fhandle) + + self.fieldnames = ['_tick', '_time'] + + if os.path.exists(self.paths['logs']): + with open(self.paths['fields'], 'r') as csvfile: + self.fieldnames = list(csv.reader(csvfile))[0] + + def log(self, to_log: Dict): + """ + Logs values to a CSV file. + """ + + to_log.update({'_tick': self._tick, '_time': time.time()}) + self._tick += 1 + + new_fields = any(k not in self.fieldnames for k in to_log) + + if new_fields: + self.fieldnames.extend(k for k in to_log if k not in self.fieldnames) + with open(self.paths['fields'], 'w') as f: + csv.writer(f).writerow(self.fieldnames) + + if to_log['_tick'] == 1: + with open(self.paths['logs'], 'a') as f: + f.write(f'# {",".join(self.fieldnames)}\n') + + self._logger.info(f'LOG | {", ".join([f"{k}: {v}" for k,v in sorted(to_log.items())])}') + + with open(self.paths['logs'], 'a') as f: + csv.DictWriter(f, fieldnames=self.fieldnames).writerow(to_log) + + def close(self, successful: bool = True): + """ + Closes the FileWriter and saves final metadata. + """ + + self.metadata['date_end'] = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + self.metadata['successful'] = successful + + self._save_metadata() + + for handler in self._logger.handlers[:]: + if isinstance(handler, logging.FileHandler): + handler.close() + self._logger.removeHandler(handler) + + def _save_metadata(self): + """ + Saves metadata to a JSON file. + """ + + with open(self.paths['meta'], 'w') as f: + json.dump(self.metadata, f, indent=4, sort_keys=True) + + +def act(actor_id: int, + config: TrainerConfig, + actor_model: nn.Module, + sample_queue: QueueType, + log_queue: QueueType, + shared_trump_prob: SynchronizedType, + shared_teacher_eps: SynchronizedType, + dropped_batches_counter: SynchronizedType, + start_seed_offset: int): + """ + Main loop for an actor process. 
+ """ + + setup_logging() + + seed = config.seed + actor_id + start_seed_offset + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + log = logging.getLogger('agent_sac_trainer') + obs_preprocessor = ObsPreprocessor() + + try: + log.info(f'Actor {actor_id} started (seed={seed}).') + env_config = { + 'seed': seed, + 'information_level': config.information_level} + env = rlcard.make(config.env, config=env_config) + + agent = AgentSAC_Actor(actor_model, 'cpu', use_teacher=config.use_teacher_forcing) + + # Main loop + while True: + aggregated_samples = [] + for _ in range(config.actor_game_batch_size): + trajectories = {p_id: [] for p_id in range(env.num_players)} + state, player_id = env.reset() + + while not env.is_over(): + action_id, legal_keys = agent.step(state, env, + trump_rule_prob=shared_trump_prob.value, + teacher_epsilon=shared_teacher_eps.value) + + sac_idx = 32 + action_id if action_id < 5 else action_id - 5 + + trajectories[player_id].append((state['obs'], sac_idx, legal_keys)) + state, player_id = env.step(action_id) + + # Reward Calculation + final_scores = env.get_payoffs() + final_pips = env.get_scores() + payoffs = np.zeros(2, dtype=np.float32) + + if config.reward_type == 'hybrid': + payoffs[0] = calculate_hybrid_reward(final_pips[0], final_pips[1], final_scores[0], + config.reward_shaping_steepness, config.reward_shaping_threshold, + config.reward_shaping_score_weight, config.reward_shaping_win_bonus) + payoffs[1] = calculate_hybrid_reward(final_pips[1], final_pips[0], final_scores[1], + config.reward_shaping_steepness, config.reward_shaping_threshold, + config.reward_shaping_score_weight, config.reward_shaping_win_bonus) + elif config.reward_type == 'binary': + payoffs[0] = calculate_binary_reward(final_scores[0]) + payoffs[1] = calculate_binary_reward(final_scores[1]) + elif config.reward_type == 'game_score': + payoffs[0] = calculate_game_score_reward(final_scores[0]) + payoffs[1] = calculate_game_score_reward(final_scores[1]) + + if config.log_p0_p1_payoffs: + log_queue.put({'p0_payoff': env.get_payoffs()[0], 'p1_payoff': env.get_payoffs()[1]}) + + # N-Step Return + for p_id, trajectory in trajectories.items(): + if not trajectory: continue + + G = float(payoffs[p_id]) + traj_len = len(trajectory) + + for i in range(traj_len): + obs, act_idx, legal_keys = trajectory[i] + n_step_end_idx = i + config.n_step_returns + + if n_step_end_idx < traj_len: + reward = 0.0 + done = False + next_obs, _, next_legal_keys = trajectory[n_step_end_idx] + else: + steps_to_end = (traj_len - 1) - i + reward = G * (config.gamma ** steps_to_end) + done = True + next_obs, _, next_legal_keys = trajectory[-1] + + sample = { + "observation": obs, + "action": [act_idx], + "legal_keys": legal_keys, + "next": { + "observation": next_obs, + "reward": [reward], + "done": [done], + "legal_keys": next_legal_keys + } + } + + processed_sample = obs_preprocessor(sample) + aggregated_samples.append(processed_sample) + + if aggregated_samples: + try: + sample_queue.put(aggregated_samples, timeout=config.sample_queue_put_timeout) + except queue.Full: + with dropped_batches_counter.get_lock(): + dropped_batches_counter.value += 1 + + except KeyboardInterrupt: + log.info(f"Actor {actor_id} interrupted.") + except Exception as e: + log.error(f'Exception in actor process {actor_id}: {e}\n{traceback.format_exc()}') + raise e + + +def learn(config: TrainerConfig, + estimator: SACEstimator, + actor_model: nn.Module, + replay_buffer: TensorDictReplayBuffer, + frames_counter: SynchronizedType, + 
learner_lock: LockType, + buffer_lock: LockType, + log_queue: QueueType, + latest_critic_loss: SynchronizedType, + latest_actor_loss: SynchronizedType, + latest_alpha: SynchronizedType, + latest_mean_q: SynchronizedType, + latest_lr: SynchronizedType): + """ + Main loop for a learner thread. + """ + + last_log_frame = 0 + + while frames_counter.value < config.total_frames: + with buffer_lock: + if len(replay_buffer) < config.min_buffer_size_to_learn: + time.sleep(1) + continue + try: + batch = replay_buffer.sample(config.batch_size) + except Exception: + time.sleep(0.1) + continue + + # Training step + with learner_lock: + c_loss, a_loss, alpha, mean_q, lr = estimator.train_step(batch, config.gradient_clip_norm) + + estimator.update_target_net() + + # Sync actor model + with torch.no_grad(): + for p_learner, p_actor in zip(estimator.net.parameters(), actor_model.parameters()): + p_actor.data.copy_(p_learner.data) + + frames_counter.value += config.batch_size + + if frames_counter.value - last_log_frame >= config.log_every_frames: + latest_critic_loss.value = c_loss + latest_actor_loss.value = a_loss + latest_alpha.value = alpha + latest_mean_q.value = mean_q + latest_lr.value = lr + + log_queue.put({ + 'type': 'train_stats', + 'frames': frames_counter.value, + 'critic_loss': c_loss, + 'actor_loss': a_loss, + 'mean_q': mean_q, + 'alpha': alpha, + 'lr': lr + }) + last_log_frame = frames_counter.value + + +class SACTrainer: + """ + Trainer for the SAC agent. + """ + + def __init__(self, config: TrainerConfig): + """ + Initialized SACTrainer. + """ + + self.config = config + self.plogger = FileWriter(xpid=config.xpid, rootdir=config.savedir, config=self.config) + self.writer = None + + if config.log_to_tensorboard: + tb_dir = os.path.join(config.savedir, config.xpid, 'tensorboard_logs') + self.writer = SummaryWriter(log_dir=tb_dir) + log.info(f"TensorBoard logging to {tb_dir}") + + self.checkpointpath = os.path.join(os.path.expandvars( + os.path.expanduser(config.savedir)), config.xpid, "model.tar") + + self.shutdown_event = threading.Event() + self.evaluator = AgentEvaluator(self.config) + + self.actor_processes = [] + self.learner_thread = None + self.logger = None + self.ingest_thread = None + + def _setup_components(self): + """ + Sets up multiprocessing components. 
+ """ + + cfg = self.config + self.ctx = mp.get_context('spawn') + log.info(f"Using learner device: {cfg.device}") + + self.estimator = SACEstimator(cfg.model_config, cfg, device=cfg.device) + + self.actor_model = BauernskatNet(cfg.model_config).to('cpu') + self.actor_model.share_memory() + self.actor_model.eval() + + self.sample_queue = self.ctx.Queue(maxsize=cfg.num_actors * cfg.actor_queue_size_multiplier) + self.log_queue = self.ctx.Queue() + + self.frames = self.ctx.Value('Q', 0) + self.learner_lock = self.ctx.Lock() + self.buffer_lock = self.ctx.Lock() + self.avg_p0_payoff = self.ctx.Value('f', 0.0) + self.dropped_batches_total = self.ctx.Value('Q', 0) + self.total_elapsed_time = self.ctx.Value('d', 0.0) + + # Logging shared variables + self.latest_critic_loss = self.ctx.Value('f', 0.0) + self.latest_actor_loss = self.ctx.Value('f', 0.0) + self.latest_alpha = self.ctx.Value('f', cfg.initial_alpha) + self.latest_mean_q = self.ctx.Value('f', 0.0) + self.latest_lr = self.ctx.Value('f', cfg.critic_lr) + + self.current_teacher_eps = self.ctx.Value('f', cfg.teacher_start if cfg.use_teacher_forcing else 0.0) + self.current_trump_prob = self.ctx.Value('f', cfg.trump_start if cfg.use_rule_based_trump_decay else 0.0) + + sampler = PrioritizedSampler(max_capacity=cfg.replay_buffer_size, alpha=cfg.per_alpha, beta=cfg.per_beta) + self.replay_buffer = TensorDictReplayBuffer( + storage=LazyTensorStorage(max_size=cfg.replay_buffer_size), + sampler=sampler, batch_size=cfg.batch_size) + + log.info("Seeding replay buffer schema...") + + try: + obs_preprocessor = ObsPreprocessor() + temp_env = rlcard.make(cfg.env, config={'seed': 999}) + temp_agent = AgentSAC_Actor(self.actor_model, 'cpu') + state, _ = temp_env.reset() + act_id, legal_keys = temp_agent.step(state, temp_env) + act_idx = 32 + act_id if act_id < 5 else act_id - 5 + + dummy_sample = { + "observation": state['obs'], + "action": [act_idx], + "legal_keys": legal_keys, + "next": { + "observation": state['obs'], + "reward": [0.0], + "done": [False], + "legal_keys": legal_keys + } + } + self.replay_buffer.add(TensorDict(obs_preprocessor(dummy_sample), batch_size=[])) + self.replay_buffer.empty() + + log.info("Replay buffer seeded.") + except Exception as e: + log.error(f"Failed to seed buffer: {e}") + raise + + # Load model + if cfg.load_model and os.path.exists(self.checkpointpath): + log.info(f"Loading checkpoint from {self.checkpointpath}") + checkpoint = torch.load(self.checkpointpath, map_location=cfg.device, weights_only=False) + + self.estimator.net.load_state_dict(checkpoint['model_state_dict']) + self.estimator.target_net.load_state_dict(checkpoint['target_state_dict']) + self.estimator.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + if 'alpha_log' in checkpoint and hasattr(self.estimator, 'log_alpha'): + self.estimator.log_alpha.data = torch.tensor([checkpoint['alpha_log']], device=cfg.device) + self.estimator.alpha = self.estimator.log_alpha.exp() + if hasattr(self.estimator, 'alpha_optim') and 'alpha_optim_state_dict' in checkpoint: + self.estimator.alpha_optim.load_state_dict(checkpoint['alpha_optim_state_dict']) + + if hasattr(self.estimator, 'scheduler') and 'scheduler_state_dict' in checkpoint: + self.estimator.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + self.frames.value = checkpoint.get('frames', 0) + self.total_elapsed_time.value = checkpoint.get('total_elapsed_time', 0.0) + self.avg_p0_payoff.value = checkpoint.get('avg_p0_payoff', 0.0) + + if 'rng_states' in checkpoint: + rng_states = 
checkpoint['rng_states'] + try: + torch.set_rng_state(rng_states['torch'].cpu()) + if torch.cuda.is_available() and rng_states['cuda'] is not None: + torch.cuda.set_rng_state_all(rng_states['cuda']) + np.random.set_state(rng_states['numpy']) + random.setstate(rng_states['python']) + log.info("RNG states restored.") + except Exception as e: + log.warning(f"Failed to restore RNG states: {e}") + + log.info(f"Resumed from {self.frames.value} frames.") + + # Initial sync of actor model + with torch.no_grad(): + for p_l, p_a in zip(self.estimator.net.parameters(), self.actor_model.parameters()): + p_a.data.copy_(p_l.data) + + def _sample_ingest_worker(self): + """ + Ingests samples from the sample queue into the replay buffer. + """ + + log.info("Sample ingest worker started.") + + while not self.shutdown_event.is_set(): + try: + batch = self.sample_queue.get(timeout=1.0) + if batch is None: break + + with self.buffer_lock: + for s in batch: + self.replay_buffer.add(TensorDict(s, batch_size=[])) + except queue.Empty: + continue + except (KeyboardInterrupt, EOFError): + break + + def start(self): + """ + Starts the training process. + """ + + self._setup_components() + + self.logger = TrainingLogger(self) + self.logger.start() + + self.ingest_thread = threading.Thread(target=self._sample_ingest_worker, daemon=True) + self.ingest_thread.start() + + self.actor_processes = [ + self.ctx.Process(target=act, args=(i, self.config, self.actor_model, + self.sample_queue, self.log_queue, self.current_trump_prob, self.current_teacher_eps, + self.dropped_batches_total, int(self.frames.value))) + for i in range(self.config.num_actors) + ] + for p in self.actor_processes: p.start() + + self.learner_thread = threading.Thread(target=learn, args=( + self.config, self.estimator, self.actor_model, self.replay_buffer, + self.frames, self.learner_lock, self.buffer_lock, self.log_queue, + self.latest_critic_loss, self.latest_actor_loss, self.latest_alpha, + self.latest_mean_q, self.latest_lr + )) + self.learner_thread.start() + + try: + last_checkpoint_frame = self.frames.value + last_eval_frame = self.frames.value + resumed_time = self.total_elapsed_time.value + start_time = time.time() + + while self.frames.value < self.config.total_frames: + time.sleep(1) + current_frames = self.frames.value + self.total_elapsed_time.value = resumed_time + (time.time() - start_time) + + # Trump Decay + if self.config.use_rule_based_trump_decay: + ratio = min(1.0, current_frames / self.config.trump_decay_frames) + self.current_trump_prob.value = self.config.trump_start - (self.config.trump_start - self.config.trump_end) * ratio + + # Teacher Forcing Decay + if self.config.use_teacher_forcing: + ratio = min(1.0, current_frames / self.config.teacher_decay_frames) + self.current_teacher_eps.value = self.config.teacher_start - (self.config.teacher_start - self.config.teacher_end) * ratio + + with self.buffer_lock: mem_size = len(self.replay_buffer) + + print(f"\rTime: {format_time(self.total_elapsed_time.value)} | " + f"Step: {current_frames/1e6:.2f}M/{self.config.total_frames/1e6:.1f}M | " + f"Mem: {mem_size/1e3:.1f}k | " + f"Teacher-ε: {self.current_teacher_eps.value:.4f} | " + f"Trump-ε: {self.current_trump_prob.value:.4f} | " + f"Alpha: {self.latest_alpha.value:.4f} | " + f"LR: {self.latest_lr.value:.3e} | " + f"ØQ: {self.latest_mean_q.value:+.4f} | " + f"C-Loss: {self.latest_critic_loss.value:.4f} | " + f"ØPayoff: {self.avg_p0_payoff.value:+.2f}", end="", flush=True) + + if current_frames - last_checkpoint_frame >= 
self.config.save_every_frames: + self.checkpoint() + last_checkpoint_frame = current_frames + + # Evaluation + if current_frames - last_eval_frame >= self.config.eval_every: + print() + eval_net = copy.deepcopy(self.estimator.net).to('cpu') + self.evaluator.evaluate(eval_net, current_frames, self.writer) + last_eval_frame = current_frames + + except KeyboardInterrupt: + print("\nTraining interrupted.") + finally: + print("\nShutting down...") + + self.shutdown_event.set() + + for p in self.actor_processes: + if p.is_alive(): p.terminate(); p.join(timeout=1.0) + + if self.ingest_thread.is_alive(): self.ingest_thread.join(timeout=1.0) + if self.logger: self.logger.stop() + + self.checkpoint() + + self.plogger.close() + if self.writer: self.writer.close() + + def checkpoint(self): + """ + Saves the current model checkpoint. + """ + + log.info(f"Saving checkpoint to {self.checkpointpath}") + + rng_states = { + 'torch': torch.get_rng_state(), + 'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None, + 'numpy': np.random.get_state(), + 'python': random.getstate() + } + + checkpoint = { + 'model_state_dict': self.estimator.net.state_dict(), + 'target_state_dict': self.estimator.target_net.state_dict(), + 'optimizer_state_dict': self.estimator.optimizer.state_dict(), + 'frames': self.frames.value, + 'total_elapsed_time': self.total_elapsed_time.value, + 'avg_p0_payoff': self.avg_p0_payoff.value, + 'rng_states': rng_states, + 'config': self.config + } + + if hasattr(self.estimator, 'log_alpha'): + checkpoint['alpha_log'] = self.estimator.log_alpha.item() + if hasattr(self.estimator, 'alpha_optim'): + checkpoint['alpha_optim_state_dict'] = self.estimator.alpha_optim.state_dict() + if hasattr(self.estimator, 'scheduler'): + checkpoint['scheduler_state_dict'] = self.estimator.scheduler.state_dict() + + torch.save(checkpoint, self.checkpointpath) + + ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + bkp = os.path.join(os.path.dirname(self.checkpointpath), f"model_{ts}_frame{self.frames.value}.tar") + torch.save(checkpoint, bkp) + + inference_checkpoint = { + 'model_state_dict': self.estimator.net.state_dict(), + 'config': self.config + } + inf_path = os.path.join(os.path.dirname(self.checkpointpath), "inference_model.pt") + inf_bkp_path = os.path.join(os.path.dirname(self.checkpointpath), f"inference_model_{ts}_frame{self.frames.value}.pt") + + torch.save(inference_checkpoint, inf_path) + torch.save(inference_checkpoint, inf_bkp_path) + log.info(f"Saved inference checkpoint to {inf_path}") + + +def main(): + """ + Main function to run the SAC trainer. 
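+
+ CLI flags are generated from the TrainerConfig dataclass fields; an
+ illustrative invocation (flag values assumed, not defaults):
+
+ python -m rlcard.agents.bauernskat.sac_agent.trainer \
+ --xpid sac_run_01 --num_actors 8 --reward_type hybrid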
+ """ + + try: + mp.set_start_method('spawn', force=True) + except RuntimeError: + pass + + setup_logging() + parser = argparse.ArgumentParser("Agent SAC Trainer for RLCard") + + for field in dataclasses.fields(TrainerConfig): + if not field.init or field.name == "model_config": continue + + if field.type == bool: + if field.default: + parser.add_argument(f'--no-{field.name}', dest=field.name, action='store_false') + else: + parser.add_argument(f'--{field.name}', dest=field.name, action='store_true') + + parser.set_defaults(**{field.name: field.default}) + else: + kwargs = {'type': field.type, 'default': field.default} + + if get_origin(field.type) is Literal: + kwargs['choices'] = get_args(field.type) + kwargs['type'] = type(kwargs['choices'][0]) + + parser.add_argument(f'--{field.name}', **kwargs) + + args = parser.parse_args() + config = TrainerConfig(**vars(args)) + os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda + + trainer = SACTrainer(config) + log.info(f"Starting training for {config.xpid} with config:\n{pprint.pformat(dataclasses.asdict(config))}") + trainer.start() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/rlcard/agents/bauernskat/sac_agent/utils.py b/rlcard/agents/bauernskat/sac_agent/utils.py new file mode 100644 index 000000000..dde1b0bc1 --- /dev/null +++ b/rlcard/agents/bauernskat/sac_agent/utils.py @@ -0,0 +1,365 @@ +''' + File name: rlcard/games/bauernskat/sac_agent/utils.py + Author: Oliver Czerwinski + Date created: 11/10/2025 + Date last modified: 12/26/2025 + Python Version: 3.9+ +''' + +import logging +import time +import queue +import threading +import numpy as np +from typing import Dict, Any, List + +import rlcard +from rlcard.agents.bauernskat import rule_agents as bauernskat_rule_agents +from rlcard.agents.bauernskat.sac_agent.agent import AgentSAC_Actor +from rlcard.agents.bauernskat.sac_agent.config import MAX_TRICK_SIZE, MAX_CEMETERY_SIZE + +def setup_logging(level=logging.INFO): + """ + Prepares the logging. + """ + + logger = logging.getLogger('agent_sac_trainer') + if logger.hasHandlers(): + return + + logger.setLevel(level) + shandle = logging.StreamHandler() + shandle.setFormatter( + logging.Formatter( + '[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] ' + '%(message)s')) + logger.addHandler(shandle) + logger.propagate = False + +class ObsPreprocessor: + """ + Preprocesses observations for the SAC agent. + """ + + def __init__(self): + """ + Initialized ObsPreprocessor. + """ + + self.pad_keys = { + 'trick_card_ids': MAX_TRICK_SIZE, + 'cemetery_card_ids': MAX_CEMETERY_SIZE, + } + + def _pad_obs(self, obs_dict: Dict) -> Dict: + """ + Pads observation of different lengths to a fixed size. + """ + + padded_dict = obs_dict.copy() + + for key, max_len in self.pad_keys.items(): + if key in padded_dict: + original_list = padded_dict[key] + padding_needed = max_len - len(original_list) + if padding_needed > 0: + padded_dict[key] = original_list + [-1] * padding_needed + + return padded_dict + + def _prepare_for_tensordict(self, data: Dict) -> Dict: + """ + Converts a list to a numpy array for tensordict compatibility. 
+ """ + + for key, value in data.items(): + if isinstance(value, dict): + self._prepare_for_tensordict(value) + elif isinstance(value, list): + data[key] = np.array(value, dtype=np.int32) + elif isinstance(value, np.ndarray) and value.dtype == np.float64: + data[key] = value.astype(np.float32) + + return data + + def _get_action_index(self, action_id: int) -> int: + """ + Maps action IDs to the legal actions mask. + """ + + if action_id < 5: + return 32 + action_id + else: + return action_id - 5 + + def _generate_legal_mask(self, legal_actions_keys: List[int]) -> np.ndarray: + """ + Generates a legal actions mask from all actions. + """ + + mask = np.zeros(38, dtype=bool) + for act_id in legal_actions_keys: + idx = self._get_action_index(act_id) + if 0 <= idx < 38: + mask[idx] = True + + return mask + + def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """ + Preprocesses a sample dictionary. + """ + + if 'observation' in sample: + sample['observation'] = self._pad_obs(sample['observation']) + if 'next' in sample and 'observation' in sample['next']: + sample['next']['observation'] = self._pad_obs(sample['next']['observation']) + + if 'legal_keys' in sample: + sample['legal_actions_mask'] = self._generate_legal_mask(sample['legal_keys']) + del sample['legal_keys'] + + if 'next' in sample and 'legal_keys' in sample['next']: + sample['next']['legal_actions_mask'] = self._generate_legal_mask(sample['next']['legal_keys']) + del sample['next']['legal_keys'] + + return self._prepare_for_tensordict(sample) + + +class TrainingLogger: + """ + Handles logging during training. + """ + + def __init__(self, trainer_instance): + """ + Initializes TrainingLogger. + """ + + self.config = trainer_instance.config + self.plogger = trainer_instance.plogger + self.writer = trainer_instance.writer + self.log_queue = trainer_instance.log_queue + self.shutdown_event = trainer_instance.shutdown_event + self.replay_buffer = trainer_instance.replay_buffer + self.buffer_lock = trainer_instance.buffer_lock + self.avg_p0_payoff = trainer_instance.avg_p0_payoff + self.dropped_batches_total = trainer_instance.dropped_batches_total + self.total_elapsed_time = trainer_instance.total_elapsed_time + + self.current_teacher_eps = trainer_instance.current_teacher_eps + self.current_trump_prob = trainer_instance.current_trump_prob + + self.thread = None + self.log = logging.getLogger('agent_sac_trainer') + + def start(self): + """ + Starts a logging thread. + """ + + self.thread = threading.Thread(target=self._log_worker, daemon=True) + self.thread.start() + + def stop(self): + """ + Stops the logging thread. + """ + + if self.thread and self.thread.is_alive(): + self.thread.join(timeout=self.config.process_join_timeout) + + def _log_worker(self): + """ + Processes log records from a queue. 
+ """ + + payoff_buffer = [] + last_log_time = time.time() + + stats = { + 'frames': 0, + 'critic_loss': 0.0, + 'actor_loss': 0.0, + 'alpha': 0.0, + 'mean_q': 0.0, + 'lr': 0.0 + } + + while not self.shutdown_event.is_set(): + try: + record = self.log_queue.get(timeout=1.0) + if record is None: break + + if record.get('type') == 'train_stats': + stats.update(record) + else: + payoff_buffer.append(record['p0_payoff']) + + except queue.Empty: + continue + except (KeyboardInterrupt, EOFError): + break + + if time.time() - last_log_time >= self.config.log_interval_seconds: + if stats['frames'] > 0: + with self.buffer_lock: + buffer_size = len(self.replay_buffer) + + log_data = { + 'Training/frames': stats['frames'], + 'Training/loss_critic': stats['critic_loss'], + 'Training/loss_actor': stats['actor_loss'], + 'Training/mean_q_values': stats['mean_q'], + 'Training/alpha': stats['alpha'], + 'Training/learning_rate': stats['lr'], + + 'Performance/buffer_size': buffer_size, + 'Performance/total_dropped_batches': self.dropped_batches_total.value, + 'Performance/total_training_time_hours': self.total_elapsed_time.value / 3600.0, + 'Performance/avg_p0_payoff_5s': self.avg_p0_payoff.value, + + 'Exploration/Teacher-epsilon': self.current_teacher_eps.value, + 'Exploration/Trump-epsilon': self.current_trump_prob.value, + } + + if payoff_buffer: + avg = np.mean(payoff_buffer) + self.avg_p0_payoff.value = avg + log_data['Performance/avg_p0_payoff_5s'] = avg + log_data['Performance/total_games_in_5s'] = len(payoff_buffer) + payoff_buffer = [] + + self.plogger.log(log_data) + + if self.writer: + for key, value in log_data.items(): + if key not in ['_tick', '_time']: + self.writer.add_scalar(key, value, stats['frames']) + + last_log_time = time.time() + + self.log.info("Log worker terminated.") + + +class AgentEvaluator: + """ + Evaluates the agent against multiple rule-based agents. + """ + + def __init__(self, config): + """ + Initializes AgentEvaluator. + """ + + self.config = config + self.log = logging.getLogger('agent_sac_trainer') + self.opponents = { + 'Random': bauernskat_rule_agents.BauernskatRandomRuleAgent(), + 'Frugal': bauernskat_rule_agents.BauernskatFrugalRuleAgent(), + 'Lookahead': bauernskat_rule_agents.BauernskatLookaheadRuleAgent(), + 'SHOT': bauernskat_rule_agents.BauernskatSHOTAlphaBetaRuleAgent() + } + + def evaluate(self, eval_net, current_frames, writer): + """ + Evaluates the agent against all opponents. 
+ """ + + self.log.info("Starting Evaluation Run...") + + eval_agent = AgentSAC_Actor(eval_net, 'cpu', use_teacher=False) + + eval_env_config = { + 'seed': 500, + 'information_level': self.config.information_level + } + eval_env = rlcard.make(self.config.env, config=eval_env_config) + + total_p0_wins, total_p1_wins = 0, 0 + total_p0_payoff, total_p1_payoff = 0.0, 0.0 + total_games_as_p0, total_games_as_p1 = 0, 0 + + win_rates_by_opponent = {} + avg_rewards_by_opponent = {} + + for name, opponent in self.opponents.items(): + games_per_opponent_half = self.config.num_eval_games // (2 * len(self.opponents)) + + eval_env.set_agents([eval_agent, opponent]) + p0_wins, p0_payoff = self._run_half(eval_env, games_per_opponent_half, agent_pos=0) + + eval_env.set_agents([opponent, eval_agent]) + p1_wins, p1_payoff = self._run_half(eval_env, games_per_opponent_half, agent_pos=1) + + # Accumulate stats + total_p0_wins += p0_wins + total_p0_payoff += p0_payoff + total_games_as_p0 += games_per_opponent_half + + total_p1_wins += p1_wins + total_p1_payoff += p1_payoff + total_games_as_p1 += games_per_opponent_half + + # Individual Opponent Stats + p0_win_rate = p0_wins / games_per_opponent_half if games_per_opponent_half > 0 else 0 + p0_avg_reward = p0_payoff / games_per_opponent_half if games_per_opponent_half > 0 else 0 + p1_win_rate = p1_wins / games_per_opponent_half if games_per_opponent_half > 0 else 0 + p1_avg_reward = p1_payoff / games_per_opponent_half if games_per_opponent_half > 0 else 0 + + self.log.info(f" vs {name}: P0 [WR: {p0_win_rate:.1%}, AvgR: {p0_avg_reward:+.2f}] | P1 [WR: {p1_win_rate:.1%}, AvgR: {p1_avg_reward:+.2f}]") + + # Combined Stats for specific opponent + total_games_this_opponent = games_per_opponent_half * 2 + if total_games_this_opponent > 0: + combined_win_rate = (p0_wins + p1_wins) / total_games_this_opponent + combined_avg_reward = (p0_payoff + p1_payoff) / total_games_this_opponent + win_rates_by_opponent[name] = combined_win_rate + avg_rewards_by_opponent[name] = combined_avg_reward + + # Overall Stats + overall_p0_win_rate = total_p0_wins / total_games_as_p0 if total_games_as_p0 > 0 else 0 + overall_p1_win_rate = total_p1_wins / total_games_as_p1 if total_games_as_p1 > 0 else 0 + overall_p0_avg_reward = total_p0_payoff / total_games_as_p0 if total_games_as_p0 > 0 else 0.0 + overall_p1_avg_reward = total_p1_payoff / total_games_as_p1 if total_games_as_p1 > 0 else 0.0 + + self.log.info(f"Overall Factual -> P0 [WR: {overall_p0_win_rate:.1%}, AvgR: {overall_p0_avg_reward:+.2f}] | P1 [WR: {overall_p1_win_rate:.1%}, AvgR: {overall_p1_avg_reward:+.2f}]") + + if writer: + writer.add_scalar('Evaluation/overall_p0_win_rate', overall_p0_win_rate, current_frames) + writer.add_scalar('Evaluation/overall_p1_win_rate', overall_p1_win_rate, current_frames) + writer.add_scalar('Evaluation/overall_p0_avg_reward', overall_p0_avg_reward, current_frames) + writer.add_scalar('Evaluation/overall_p1_avg_reward', overall_p1_avg_reward, current_frames) + + if win_rates_by_opponent: + writer.add_scalars('Evaluation/Combined_WinRate_vs_Opponent', win_rates_by_opponent, current_frames) + + if avg_rewards_by_opponent: + writer.add_scalars('Evaluation/Combined_AvgR_vs_Opponent', avg_rewards_by_opponent, current_frames) + + def _run_half(self, env, num_games, agent_pos): + """ + Runs games in a specified player role. 
+ """ + + total_wins = 0 + total_payoff = 0.0 + agent = env.agents[agent_pos] + opponent = env.agents[1 - agent_pos] + + for _ in range(num_games): + state, player_id = env.reset() + + while not env.is_over(): + if player_id == agent_pos: + action, _ = agent.eval_step(state, env) + else: + action = opponent.step(state) + state, player_id = env.step(action) + + payoffs = env.get_payoffs() + total_payoff += payoffs[agent_pos] + if payoffs[agent_pos] > 0: + total_wins += 1 + + return total_wins, total_payoff \ No newline at end of file diff --git a/rlcard/envs/__init__.py b/rlcard/envs/__init__.py index de9dbb8c1..96597f506 100644 --- a/rlcard/envs/__init__.py +++ b/rlcard/envs/__init__.py @@ -47,3 +47,8 @@ env_id='bridge', entry_point='rlcard.envs.bridge:BridgeEnv', ) + +register( + env_id='bauernskat', + entry_point='rlcard.envs.bauernskat:BauernskatEnv', +) \ No newline at end of file diff --git a/rlcard/envs/bauernskat.py b/rlcard/envs/bauernskat.py new file mode 100644 index 000000000..439ec0523 --- /dev/null +++ b/rlcard/envs/bauernskat.py @@ -0,0 +1,198 @@ +''' + File name: rlcard/envs/bauernskat.py + Author: Oliver Czerwinski + Date created: 08/02/2025 + Date last modified: 25/12/2025 + Python Version: 3.9+ +''' + +import copy +import numpy as np + +from rlcard.envs import Env +from rlcard.games.bauernskat.game import BauernskatGame as Game +from rlcard.games.bauernskat.action_event import ActionEvent, DeclareTrumpAction +from rlcard.games.bauernskat import config + + +class BauernskatEnv(Env): + """ + Bauernskat Environment wrapper for RLCard. + """ + + def __init__(self, config=None): + """ + Inititialized BauernskatEnv. + """ + + if config is None: + config = {} + self.name = 'bauernskat' + self.game = Game(information_level=config.get('information_level', 'normal')) + super().__init__(config) + + self.state_shape = {} + self.action_shape = [None for _ in range(self.num_players)] + + def seed(self, seed: int) -> None: + """ + Sets a seed. + """ + + self.game.np_random = np.random.RandomState(seed) + + def _get_legal_actions(self): + """ + Gets the legal actions from judger. + """ + + legal_actions = self.game.judger.get_legal_actions() + return {action.action_id: True for action in legal_actions} + + def _extract_state(self, state): + """ + Extracts state representation from the game state as a dictionary. 
+ """ + + raw_info = state['raw_state_info'] + obs = {} + + # Layouts: (8, 2) tensor [open_card_id, hidden_card_id] + my_layout_tensor = np.full((config.NUM_COLUMNS_PER_PLAYER, 2), 32, dtype=np.int32) + for i, col in enumerate(raw_info['my_layout']): + if col.open_card: + my_layout_tensor[i, 0] = col.open_card.card_id + + # The hidden card IDs dependent on information level + if col.closed_card and raw_info.get('my_hidden_cards'): + my_layout_tensor[i, 1] = col.closed_card.card_id + obs['my_layout_tensor'] = my_layout_tensor + + # Opponent layout + opponent_layout_tensor = np.full((config.NUM_COLUMNS_PER_PLAYER, 2), 32, dtype=np.int32) + for i, col in enumerate(raw_info['opponent_layout']): + if col.open_card: + opponent_layout_tensor[i, 0] = col.open_card.card_id + + # The hidden card IDs dependent on information level + if col.closed_card and raw_info.get('opponent_hidden_cards'): + opponent_layout_tensor[i, 1] = col.closed_card.card_id + obs['opponent_layout_tensor'] = opponent_layout_tensor + + # Unaccounted card mask: 32 vector + known_card_ids = set(card.card_id for card in raw_info['played_cards']) + known_card_ids.update(card.card_id for _, card in raw_info['trick_moves']) + + known_card_ids.update(my_layout_tensor[my_layout_tensor != 32]) + known_card_ids.update(opponent_layout_tensor[opponent_layout_tensor != 32]) + + unaccounted_mask = np.ones(32, dtype=np.float32) + for card_id in known_card_ids: + if 0 <= card_id < 32: + unaccounted_mask[card_id] = 0.0 + obs['unaccounted_cards_mask'] = unaccounted_mask + + # Current trick and cementery card IDs: lists + obs['trick_card_ids'] = [card.card_id for _, card in raw_info['trick_moves']] + obs['cemetery_card_ids'] = [card.card_id for card in raw_info['played_cards']] + + # Hidden Card Indicators: 8 vector + my_hidden = np.zeros(config.NUM_COLUMNS_PER_PLAYER, dtype=np.float32) + for i, col in enumerate(raw_info['my_layout']): + if col.has_card_underneath(): + my_hidden[i] = 1.0 + obs['my_hidden_indicators'] = my_hidden + + opponent_hidden = np.zeros(config.NUM_COLUMNS_PER_PLAYER, dtype=np.float32) + for i, col in enumerate(raw_info['opponent_layout']): + if col.has_card_underneath(): + opponent_hidden[i] = 1.0 + obs['opponent_hidden_indicators'] = opponent_hidden + + # Normalized Context Feature Vector (11,) + MAX_SCORE = 480.0 + MAX_TRICKS = 16.0 + + context = np.zeros(11, dtype=np.float32) + + if raw_info['trump_suit'] is not None: + trump_idx = DeclareTrumpAction.VALID_TRUMPS.index(raw_info['trump_suit']) + context[trump_idx] = 1.0 + + context[5] = 1.0 if raw_info['round_phase'] == 'play' else 0.0 + context[6] = float(raw_info['current_player_id']) + context[7] = float(raw_info['trick_leader_id']) + context[8] = np.clip(float(raw_info['my_score']) / MAX_SCORE, 0.0, 1.0) + context[9] = np.clip(float(raw_info['opponent_score']) / MAX_SCORE, 0.0, 1.0) + context[10] = float(raw_info['tricks_played']) / MAX_TRICKS + + obs['context'] = context + + # Padded Action History Tensor + history_tensor = np.zeros((config.HISTORY_SEQUENCE_LENGTH, config.HISTORY_FRAME_SIZE), dtype=np.float32) + history_frames = raw_info['history_frames'] + if history_frames: + num_frames = len(history_frames) + if num_frames > 0: + history_matrix = np.vstack(history_frames) + + # Pad at the beginning of the tensor + history_tensor[-num_frames:] = history_matrix + + obs['action_history'] = history_tensor + + return { + 'obs': obs, + 'legal_actions': self._get_legal_actions(), + 'raw_state_info': raw_info, + 'raw_legal_actions': 
list(self._get_legal_actions().keys()), + 'action_record': self.action_recorder, + } + + def get_payoffs(self): + """ + Gets the payoffs of the game. + """ + + return np.array(self.game.get_payoffs(), dtype=np.float32) + + def get_scores(self): + """ + Gets the scores of the game. + """ + + return np.array([p.score for p in self.game.players], dtype=np.int32) + + def _decode_action(self, action_id): + """ + Decodes an action ID into ActionEvent. + """ + + return ActionEvent.from_action_id(action_id) + + def get_perfect_information(self): + """ + Gets the perfect information of the game. + """ + + p0 = self.game.players[0] + p1 = self.game.players[1] + + return { + 'player_0_layout_open': [c.card_id for c in p0.get_playable_cards()], + 'player_0_layout_hidden': [c.card_id for c in p0.get_hidden_cards()], + 'player_1_layout_open': [c.card_id for c in p1.get_playable_cards()], + 'player_1_layout_hidden': [c.card_id for c in p1.get_hidden_cards()], + 'trump_suit': self.game.round.trump_suit if self.game.round else None, + 'current_phase': self.game.round.round_phase if self.game.round else None, + 'trick_moves': [c.card_id for _, c in (self.game.round.trick_moves if self.game.round else [])] + } + + def clone(self): + """ + Creates a copy of the environment object. + """ + + cloned_env = BauernskatEnv(copy.copy(self.config)) + cloned_env.game = self.game.clone() + return cloned_env \ No newline at end of file diff --git a/rlcard/games/bauernskat/__init__.py b/rlcard/games/bauernskat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rlcard/games/bauernskat/action_event.py b/rlcard/games/bauernskat/action_event.py new file mode 100644 index 000000000..082c8a52c --- /dev/null +++ b/rlcard/games/bauernskat/action_event.py @@ -0,0 +1,130 @@ +''' + File name: rlcard/games/bauernskat/action_event.py + Author: Oliver Czerwinski + Date created: 07/17/2025 + Date last modified: 12/25/2025 + Python Version: 3.9+ +''' + +from . import config +from .card import BauernskatCard + + +class ActionEvent: + """ + A base class for actions that can occur in Bauernskat. + """ + + def __init__(self, action_id: int) -> None: + """ + Initializes ActionEvent. + """ + self.action_id: int = action_id + + def __eq__(self, other: object) -> bool: + """ + Equality check based on action_id. + """ + if not isinstance(other, ActionEvent): + return NotImplemented + return self.action_id == other.action_id + + def __hash__(self) -> int: + """ + Hash based on action ID. + """ + return hash(self.action_id) + + @staticmethod + def from_action_id(action_id: int) -> 'ActionEvent': + """ + Creates an action event based on its ID. + """ + if config.FIRST_DECLARE_TRUMP_ACTION_ID <= action_id < config.FIRST_PLAY_CARD_ACTION_ID: + return DeclareTrumpAction.from_action_id(action_id) + + if config.FIRST_PLAY_CARD_ACTION_ID <= action_id < config.TOTAL_NUM_ACTIONS: + card_id = action_id - config.FIRST_PLAY_CARD_ACTION_ID + card = BauernskatCard.card(card_id=card_id) + return PlayCardAction(card=card) + + raise ValueError(f"Invalid action_id {action_id}. Must be between 0 and {config.TOTAL_NUM_ACTIONS - 1}.") + + @staticmethod + def get_num_actions() -> int: + """ + Returns the number of possible actions. + """ + return config.TOTAL_NUM_ACTIONS + + +class DeclareTrumpAction(ActionEvent): + """ + ActionEvent for declaring a trump suit or a grand. + """ + + VALID_TRUMPS: tuple[str, ...] = BauernskatCard.suits + ('G',) + + def __init__(self, trump_suit: str) -> None: + """ + Initializes DeclareTrumpAction. 
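+
+ Illustrative doctest (ids follow from FIRST_DECLARE_TRUMP_ACTION_ID == 0
+ and VALID_TRUMPS == ('C', 'S', 'H', 'D', 'G')):
+
+ >>> DeclareTrumpAction('H').action_id
+ 2
+ >>> DeclareTrumpAction('G').action_id
+ 4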
+ """ + + if trump_suit not in self.VALID_TRUMPS: + raise ValueError(f"Invalid trump suit '{trump_suit}'. Must be one of {self.VALID_TRUMPS}.") + + self.trump_suit: str = trump_suit + action_id = config.FIRST_DECLARE_TRUMP_ACTION_ID + self.VALID_TRUMPS.index(trump_suit) + + super().__init__(action_id) + + @classmethod + def from_action_id(cls, action_id: int) -> 'DeclareTrumpAction': + """ + Creates a DeclareTrumpAction from the ID. + """ + + trump_index = action_id - config.FIRST_DECLARE_TRUMP_ACTION_ID + trump_suit = cls.VALID_TRUMPS[trump_index] + + return cls(trump_suit) + + def __str__(self) -> str: + """ + String representation of a DeclareTrumpAction. + """ + return f"Declare {self.trump_suit}" + + def __repr__(self) -> str: + """ + Representation of a DeclareTrumpAction object. + """ + return f"DeclareTrumpAction(trump_suit='{self.trump_suit}')" + + +class PlayCardAction(ActionEvent): + """ + ActionEvent for playing a card. + """ + + def __init__(self, card: BauernskatCard) -> None: + """ + Initializes PlayCardAction. + """ + + self.card: BauernskatCard = card + action_id = config.FIRST_PLAY_CARD_ACTION_ID + card.card_id + + super().__init__(action_id) + + def __str__(self) -> str: + """ + String representation of a PlayCardAction. + """ + return f"Play {self.card}" + + def __repr__(self) -> str: + """ + Representation of a PlayCardAction object. + """ + return f"PlayCardAction(card={repr(self.card)})" \ No newline at end of file diff --git a/rlcard/games/bauernskat/card.py b/rlcard/games/bauernskat/card.py new file mode 100644 index 000000000..0b7744b01 --- /dev/null +++ b/rlcard/games/bauernskat/card.py @@ -0,0 +1,83 @@ +''' + File name: rlcard/games/bauernskat/card.py + Author: Oliver Czerwinski + Date created: 17/07/2025 + Date last modified: 25/12/2025 + Python Version: 3.9+ +''' + +from . import config +from rlcard.games.base import Card + + +class BauernskatCard(Card): + """ + BauernskatCard implements the properties of a card in a 32-card Skat deck. + """ + + suits: tuple[str, ...] = config.VALID_SUITS + ranks: tuple[str, ...] = config.VALID_RANKS + + def __init__(self, suit: str, rank: str) -> None: + """ + Initializes a BauernskatCard. + """ + super().__init__(suit, rank) + + if suit not in self.suits: + raise ValueError(f"Invalid suit '{suit}'. Must be one of {self.suits}.") + if rank not in self.ranks: + raise ValueError(f"Invalid rank '{rank}'. Must be one of {self.ranks}.") + + # Calculate IDs for all combinations + suit_index: int = self.suits.index(suit) + rank_index: int = self.ranks.index(rank) + self.card_id: int = suit_index * len(self.ranks) + rank_index + + self.points: int = config.RANK_VALUES[self.rank] + + def __str__(self) -> str: + """ + String representation of a card. + """ + return f'{self.rank}{self.suit}' + + def __repr__(self) -> str: + """ + Representation of card object. + """ + return f"BauernskatCard('{self.suit}', '{self.rank}')" + + def __eq__(self, other: object) -> bool: + """ + Equality check based on card ID. + """ + if not isinstance(other, BauernskatCard): + return NotImplemented + return self.card_id == other.card_id + + def __hash__(self) -> int: + """ + Hash based on card ID. + """ + return hash(self.card_id) + + @staticmethod + def card(card_id: int) -> 'BauernskatCard': + """ + Gets a valid card instance from the deck using the ID. + """ + if not 0 <= card_id < len(_DECK): + raise IndexError(f"card_id {card_id} is out of range. 
Must be between 0 and 31.") + return _DECK[card_id] + + @staticmethod + def get_deck() -> list['BauernskatCard']: + """ + Returns a copy of the 32-card deck. + """ + return _DECK.copy() + + +# Source deck to only generate once +_DECK: list[BauernskatCard] = [BauernskatCard(suit, rank) for suit in config.VALID_SUITS for rank in config.VALID_RANKS] \ No newline at end of file diff --git a/rlcard/games/bauernskat/config.py b/rlcard/games/bauernskat/config.py new file mode 100644 index 000000000..96dea8352 --- /dev/null +++ b/rlcard/games/bauernskat/config.py @@ -0,0 +1,43 @@ +''' + File name: rlcard/games/bauernskat/config.py + Author: Oliver Czerwinski + Date created: 07/17/2025 + Date last modified: 12/25/2025 +''' + +# Deck Configuration + +# Defines suits and ranks +VALID_SUITS: tuple[str, ...] = ('C', 'S', 'H', 'D') +VALID_RANKS: tuple[str, ...] = ('7', '8', '9', 'Q', 'K', '10', 'A', 'J') + +# Defines pip value of each rank +RANK_VALUES: dict[str, int] = { + '7': 0, '8': 0, '9': 0, 'Q': 3, 'K': 4, '10': 10, 'A': 11, 'J': 2 +} + + +# Player and Layout Configuration + +NUM_PLAYERS: int = 2 +NUM_COLUMNS_PER_PLAYER: int = 8 + + +# Action Space Configuration + +GRAND_AVAILABLE: bool = True + +# Trump actions (4 suits + 1 grand) +NUM_DECLARE_TRUMP_ACTIONS: int = len(VALID_SUITS) + (int(GRAND_AVAILABLE) * 1) + +FIRST_DECLARE_TRUMP_ACTION_ID: int = 0 +FIRST_PLAY_CARD_ACTION_ID: int = FIRST_DECLARE_TRUMP_ACTION_ID + NUM_DECLARE_TRUMP_ACTIONS + +NUM_CARDS_IN_DECK: int = len(VALID_SUITS) * len(VALID_RANKS) +TOTAL_NUM_ACTIONS: int = NUM_DECLARE_TRUMP_ACTIONS + NUM_CARDS_IN_DECK + + +# History Construction Configuration + +HISTORY_SEQUENCE_LENGTH: int = 33 +HISTORY_FRAME_SIZE: int = 49 \ No newline at end of file diff --git a/rlcard/games/bauernskat/dealer.py b/rlcard/games/bauernskat/dealer.py new file mode 100644 index 000000000..f1942ba47 --- /dev/null +++ b/rlcard/games/bauernskat/dealer.py @@ -0,0 +1,89 @@ +''' + File name: rlcard/games/bauernskat/dealer.py + Author: Oliver Czerwinski + Date created: 07/17/2025 + Date last modified: 12/25/2025 + Python Version: 3.9+ +''' + +from typing import List +import numpy as np + +from .player import BauernskatPlayer +from .card import BauernskatCard +from . import config + + +class BauernskatDealer: + """ + The BauernskatDealer shuffles the deck and deals cards in the two phases. + """ + + def __init__(self, np_random: np.random.RandomState) -> None: + """ + Initializes BauernskatDealer. + """ + + self.np_random: np.random.RandomState = np_random + + self.shuffled_deck: List[BauernskatCard] = BauernskatCard.get_deck() + self.np_random.shuffle(self.shuffled_deck) + + self._card_stack: List[BauernskatCard] = self.shuffled_deck.copy() + + def deal_phase_one(self, players: List[BauernskatPlayer]) -> None: + """ + Deals the first 12 cards to set up for trump declaration. + - Deals 4 closed cards to Vorhand. + - Deals 4 closed cards to Geber. + - Deals 4 open cards to Vorhand. + Then the dealer will pause for the Vorhand to declare trump before continuing. + """ + + assert len(players) == config.NUM_PLAYERS, "Dealing requires exactly two players." 
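+
+ # Cards are popped from the top of the shuffled 32-card stack; phase one
+ # consumes 12 cards and deal_phase_two deals the remaining 20.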
+ + vorhand, geber = players[0], players[1] + num_cols_half = config.NUM_COLUMNS_PER_PLAYER // 2 + + # Deal 4 closed cards to Vorhand + for i in range(0, num_cols_half): + vorhand.layout[i].closed_card = self._card_stack.pop() + + # Deal 4 closed cards to Geber + for i in range(0, num_cols_half): + geber.layout[i].closed_card = self._card_stack.pop() + + # Deal 4 open cards to Vorhand + for i in range(0, num_cols_half): + vorhand.layout[i].open_card = self._card_stack.pop() + + def deal_phase_two(self, players: List[BauernskatPlayer]) -> None: + """ + Deal the remaining 20 cards after the trump has been declared. + """ + assert len(players) == config.NUM_PLAYERS, "Dealing requires exactly two players." + + vorhand, geber = players[0], players[1] + num_cols_half = config.NUM_COLUMNS_PER_PLAYER // 2 + + # Deal 4 open cards to Geber + for i in range(0, num_cols_half): + geber.layout[i].open_card = self._card_stack.pop() + + # Deal 4 closed cards to Vorhand + for i in range(num_cols_half, config.NUM_COLUMNS_PER_PLAYER): + vorhand.layout[i].closed_card = self._card_stack.pop() + + # Deal 4 closed cards to Geber + for i in range(num_cols_half, config.NUM_COLUMNS_PER_PLAYER): + geber.layout[i].closed_card = self._card_stack.pop() + + # Deal 4 open to Vorhand + for i in range(num_cols_half, config.NUM_COLUMNS_PER_PLAYER): + vorhand.layout[i].open_card = self._card_stack.pop() + + # Deal 4 open cards to Geber + for i in range(num_cols_half, config.NUM_COLUMNS_PER_PLAYER): + geber.layout[i].open_card = self._card_stack.pop() + + assert len(self._card_stack) == 0, "All cards should have been dealt." \ No newline at end of file diff --git a/rlcard/games/bauernskat/game.py b/rlcard/games/bauernskat/game.py new file mode 100644 index 000000000..246de040a --- /dev/null +++ b/rlcard/games/bauernskat/game.py @@ -0,0 +1,265 @@ +''' + File name: rlcard/games/bauernskat/game.py + Author: Oliver Czerwinski + Date created: 07/17/2025 + Date last modified: 12/25/2025 + Python Version: 3.9+ +''' + +import copy +from typing import List, Dict, Any +import numpy as np + +from . import config +from .player import BauernskatPlayer +from .dealer import BauernskatDealer +from .round import BauernskatRound +from .judger import BauernskatJudger +from .action_event import ActionEvent, DeclareTrumpAction, PlayCardAction +from .card import BauernskatCard + + +class BauernskatGame: + """ + BauernskatGame runs the game loop and provides information. + """ + + def __init__(self, allow_step_back: bool = False, information_level: str = 'normal') -> None: + """ + Initializes BauernskatGame. + """ + + self.allow_step_back = allow_step_back + self.information_level = information_level + self.np_random: np.random.RandomState = np.random.RandomState() + + self.players: List[BauernskatPlayer] = [] + self.round: BauernskatRound = None + self.judger: BauernskatJudger = BauernskatJudger(game=self) + + def get_num_players(self) -> int: + """ + Returns the number of players. + """ + + return config.NUM_PLAYERS + + @staticmethod + def get_num_actions() -> int: + """ + Returns the total number of unique actions in the game. + """ + + return ActionEvent.get_num_actions() + + def init_game(self) -> tuple[Dict[str, Any], int]: + """ + Initializes a new game of Bauernskat. 
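+
+ Returns the initial state and the id of the player to act; the game
+ starts in the 'declare_trump' phase with the Vorhand (player 0) to move.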
+ """ + + self.players = [BauernskatPlayer(i, self.np_random) for i in range(config.NUM_PLAYERS)] + dealer = BauernskatDealer(self.np_random) + + dealer.deal_phase_one(self.players) + + self.round = BauernskatRound(dealer, self.players, self.np_random) + + current_player_id = self.get_player_id() + state = self.get_state(current_player_id) + + return state, current_player_id + + def step(self, action: ActionEvent) -> tuple[Dict[str, Any], int]: + """ + Executes an action and transitions to the next state. + """ + + if isinstance(action, int): + decoded_action = ActionEvent.from_action_id(action) + else: + decoded_action = action + + if self.is_over(): + raise ValueError("Cannot perform an action in a completed game.") + + if isinstance(decoded_action, DeclareTrumpAction): + self.round.declare_trump(decoded_action) + elif isinstance(decoded_action, PlayCardAction): + self.round.play_card(decoded_action) + + next_player_id = self.get_player_id() + next_state = self.get_state(next_player_id) + + return next_state, next_player_id + + def get_state(self, player_id: int) -> Dict[str, Any]: + """ + Generates a state representation for a specific player. + """ + + player = self.players[player_id] + opponent = self.players[1 - player_id] + + legal_actions = self.judger.get_legal_actions() + legal_actions_dict = {action.action_id: True for action in legal_actions} + + # Contains all dynamic informations about the game state. + raw_state_info = { + 'round_phase': self.round.round_phase if self.round else 'declare_trump', + 'my_cards': player.get_playable_cards(), + 'opponent_visible_cards': opponent.get_playable_cards(), + 'my_layout': player.layout, + 'opponent_layout': opponent.layout, + 'trick_moves': self.round.trick_moves if self.round else [], + 'trump_suit': self.round.trump_suit if self.round else None, + 'current_player_id': self.get_player_id(), + 'trick_leader_id': self.round.trick_leader_id if self.round else 0, + 'tricks_played': self.round.tricks_played if self.round else 0, + 'player_id': player_id, + 'my_score': player.score, + 'opponent_score': opponent.score, + 'played_cards': self.round.played_cards if self.round else set(), + 'history_frames': self.round.history_frames if self.round else [], + } + + # Constraints asymmetric information level for a player. + if isinstance(self.information_level, dict): + current_level = self.information_level.get(player_id, 'normal') + else: + current_level = self.information_level + + # Optionally add hidden card informations based on the information level. + if current_level in ('show_self', 'perfect'): + raw_state_info['my_hidden_cards'] = player.get_hidden_cards() + else: + raw_state_info['my_hidden_cards'] = [] + + if current_level == 'perfect': + raw_state_info['opponent_hidden_cards'] = opponent.get_hidden_cards() + else: + raw_state_info['opponent_hidden_cards'] = [] + + state = { + 'legal_actions': legal_actions_dict, + 'raw_state_info': raw_state_info + } + + if self.is_over(): + p0_score = self.players[0].score + p1_score = 120 - p0_score + + if player_id == 0: + state['raw_state_info']['pip_difference'] = p0_score - p1_score + else: + state['raw_state_info']['pip_difference'] = p1_score - p0_score + + game_payoffs = self.get_payoffs() + state['raw_state_info']['game_value_payoff'] = game_payoffs[player_id] + + return state + + def get_payoffs(self) -> List[float]: + """ + Determines the payoffs at the end of the game based on actual Skat scoring. 
+ """ + + if not self.is_over(): + return [0.0, 0.0] + + # Winner based on pips + p0_score = self.players[0].score + + assert 0 <= p0_score <= 120, f"Invalid score for Player 0: {p0_score}" + + p1_score = 120 - p0_score + + # A tie means the Geber wins + declarer_wins = p0_score >= 61 + + # Base game value + base_values = {'C': 12, 'S': 11, 'H': 10, 'D': 9, 'G': 24} + trump_suit = self.round.trump_suit + base_value = base_values.get( trump_suit, 0) + + # Matador count + shuffled_deck = self.round.dealer.shuffled_deck + p0_card_indices = [ + 31, 30, 29, 28, 23, 22, 21, 20, 15, 14, 13, 12, 7, 6, 5, 4 + ] + p0_initial_hand = {shuffled_deck[i] for i in p0_card_indices} + + jacks_in_order = [ + BauernskatCard('C', 'J'), BauernskatCard('S', 'J'), + BauernskatCard('H', 'J'), BauernskatCard('D', 'J') + ] + + matador_count = 0 + has_club_jack = jacks_in_order[0] in p0_initial_hand + + if has_club_jack: + # With n matadors + for jack in jacks_in_order: + if jack in p0_initial_hand: matador_count += 1 + else: break + else: + # Without n matadors + for jack in jacks_in_order: + if jack not in p0_initial_hand: matador_count += 1 + else: break + + # Base multiplier + base_multiplier = matador_count + 1 + game_value = base_value * base_multiplier + + # Calculate payoffs + if declarer_wins: + final_score = game_value + # Apply Schneider or Schwarz multipliers + if p1_score == 0: + final_score *= 4 + elif p1_score < 31: + final_score *= 2 + else: + # The Vorhand loses twice the game value + final_score = -2 * game_value + + return [float(final_score), float(-final_score)] + + def get_player_id(self) -> int: + """ + Returns the ID of the player of the current turn. + """ + + if self.round is None: + return 0 + + return self.round.current_player_id + + def is_over(self) -> bool: + """ + Checks if the game has ended. + """ + + if self.round is None: + return False + + return self.round.is_over() + + def clone(self): + """ + Creates a deep copy of the game object for simulations. + """ + + cloned_game = BauernskatGame(allow_step_back=self.allow_step_back, information_level=self.information_level) + cloned_game.np_random.set_state(self.np_random.get_state()) + + if self.round: + cloned_game.round = copy.deepcopy(self.round) + cloned_game.players = cloned_game.round.players + else: + cloned_game.players = copy.deepcopy(self.players) + cloned_game.round = None + + cloned_game.judger = BauernskatJudger(game=cloned_game) + + return cloned_game \ No newline at end of file diff --git a/rlcard/games/bauernskat/judger.py b/rlcard/games/bauernskat/judger.py new file mode 100644 index 000000000..7cdc39840 --- /dev/null +++ b/rlcard/games/bauernskat/judger.py @@ -0,0 +1,95 @@ +''' + File name: rlcard/games/bauernskat/judger.py + Author: Oliver Czerwinski + Date created: 07/17/2025 + Date last modified: 12/25/2025 + Python Version: 3.9+ +''' + +from typing import List, TYPE_CHECKING + +from .action_event import ActionEvent, DeclareTrumpAction, PlayCardAction +from .card import BauernskatCard + +if TYPE_CHECKING: + from .game import BauernskatGame + + +class BauernskatJudger: + """ + Determines the set of legal actions for a player at any point of time. + """ + + def __init__(self, game: 'BauernskatGame') -> None: + """ + Initializes BauernskatJudger. + """ + + self.game: 'BauernskatGame' = game + + def _is_trump(self, card: BauernskatCard, trump_suit: str) -> bool: + """ + Determines if a card is a trump card in the game. + - Grand: Only Jacks. + - Color Suit: All Jacks and all cards of the selected suit. 
+ """ + + assert trump_suit is not None, "Trump suit must be declared to check for trumps." + + if trump_suit == 'G': + return card.rank == 'J' + + if card.rank == 'J': + return True + + if card.suit == trump_suit: + return True + + return False + + def get_legal_actions(self) -> List[ActionEvent]: + """ + List of legal actions for the current player. + """ + + round = self.game.round + + if round.is_over(): + return [] + + if round.round_phase == 'declare_trump': + return [DeclareTrumpAction(suit) for suit in DeclareTrumpAction.VALID_TRUMPS] + + if round.round_phase == 'play': + current_player = round.players[round.current_player_id] + playable_cards = current_player.get_playable_cards() + + if not playable_cards: + return [] + + # Player is starting the trick: Any card is legal + if not round.trick_moves: + return [PlayCardAction(card) for card in playable_cards] + + # Player is answering the trick: Must play specific suit/trump if possible + led_card = round.trick_moves[0][1] + trump_suit = round.trump_suit + + # A trump card has been played: Player must answer a trump card if possible + if self._is_trump(led_card, trump_suit): + trumps_in_hand = [card for card in playable_cards if self._is_trump(card, trump_suit)] + if trumps_in_hand: + return [PlayCardAction(card) for card in trumps_in_hand] + + # A non-trump card has been played: Player must answer with that suit if possible + else: + led_suit = led_card.suit + + suit_in_hand = [card for card in playable_cards if card.suit == led_suit and not self._is_trump(card, trump_suit)] + if suit_in_hand: + return [PlayCardAction(card) for card in suit_in_hand] + + # If the player has no fitting cards, they can answer with any other card + return [PlayCardAction(card) for card in playable_cards] + + return [] \ No newline at end of file diff --git a/rlcard/games/bauernskat/player.py b/rlcard/games/bauernskat/player.py new file mode 100644 index 000000000..4e3e8a357 --- /dev/null +++ b/rlcard/games/bauernskat/player.py @@ -0,0 +1,115 @@ +''' + File name: rlcard/games/bauernskat/player.py + Author: Oliver Czerwinski + Date created: 07/17/2025 + Date last modified: 12/25/2025 + Python Version: 3.9+ +''' + +from typing import Optional, List + +from . import config +from .card import BauernskatCard + + +class _CardColumn: + """ + Helper class to represent a single stack of cards on the table. + """ + + def __init__(self) -> None: + """ + Initializes empty card column. + """ + + self.open_card: Optional[BauernskatCard] = None + self.closed_card: Optional[BauernskatCard] = None + + def is_playable(self) -> bool: + """ + A column is playable if it has an open card. + """ + + return self.open_card is not None + + def has_card_underneath(self) -> bool: + """ + Checks if playing the open card will reveal a closed card. + """ + + return self.closed_card is not None + + def play_card(self) -> BauernskatCard: + """ + Removes the open card and moves the closed card to the open position. + """ + + assert self.is_playable(), "Cannot play a card from an empty column." + + played_card = self.open_card + self.open_card = self.closed_card + self.closed_card = None + + return played_card + + def __repr__(self) -> str: + """ + Representation of the card column. + """ + return f"_CardColumn(open={self.open_card}, closed={self.closed_card is not None})" + + +class BauernskatPlayer: + """ + Manages the state of one player in the game. + """ + + def __init__(self, player_id: int, np_random) -> None: + """ + Initializes BauernskatPlayer. 
+ """ + + if player_id not in {0, 1}: + raise ValueError(f"Invalid player_id '{player_id}'. Must be 0 or 1.") + + self.player_id: int = player_id + self.np_random = np_random + self.score: int = 0 + self.layout: List[_CardColumn] = [_CardColumn() for _ in range(config.NUM_COLUMNS_PER_PLAYER)] + + def get_playable_cards(self) -> List[BauernskatCard]: + """ + Returns a list of all cards that are open and can be played. + """ + + return [col.open_card for col in self.layout if col.is_playable()] + + def get_hidden_cards(self) -> List[BauernskatCard]: + """ + Returns a list of all cards that are closed. + """ + + return [col.closed_card for col in self.layout if col.has_card_underneath()] + + def find_column_for_card(self, card_to_find: BauernskatCard) -> Optional[_CardColumn]: + """ + Finds the _CardColumn that currently holds the specific card. + """ + + for column in self.layout: + if column.is_playable() and column.open_card == card_to_find: + return column + + return None + + def add_points(self, points: int) -> None: + """ + Adds points from a won trick to the player score. + """ + self.score += points + + def __str__(self) -> str: + """ + String representation of the player. + """ + return f"Player {self.player_id}" \ No newline at end of file diff --git a/rlcard/games/bauernskat/round.py b/rlcard/games/bauernskat/round.py new file mode 100644 index 000000000..c0130ef82 --- /dev/null +++ b/rlcard/games/bauernskat/round.py @@ -0,0 +1,192 @@ +''' + File name: rlcard/games/bauernskat/round.py + Author: Oliver Czerwinski + Date created: 07/17/2025 + Date last modified: 12/25/2025 + Python Version: 3.9+ +''' + +from typing import List, Optional, Tuple +import numpy as np + +from . import config +from .player import BauernskatPlayer +from .dealer import BauernskatDealer +from .card import BauernskatCard +from .action_event import ActionEvent, DeclareTrumpAction, PlayCardAction + + +class BauernskatRound: + """ + Manages the state and progression of a single round in Bauernskat. + """ + + def __init__(self, dealer: BauernskatDealer, players: List[BauernskatPlayer], np_random: np.random.RandomState) -> None: + """ + Initializes BauernskatRound. + """ + + self.np_random: np.random.RandomState = np_random + self.dealer: BauernskatDealer = dealer + self.players: List[BauernskatPlayer] = players + + self.round_phase: str = 'declare_trump' + self.current_player_id: int = 0 + self.trick_leader_id: int = 0 + self.trump_suit: Optional[str] = None + + self.trick_moves: List[Tuple[int, BauernskatCard]] = [] + self.tricks_played: int = 0 + self.played_cards: set[BauernskatCard] = set() + + self.history_frames: List[np.ndarray] = [] + + def _create_history_frame(self, action: ActionEvent) -> np.ndarray: + """ + Creates a history frame vector based on the current game state and taken action. 
+ """ + acting_player_id = self.current_player_id + + # Encode Player + player_vec = np.zeros(2, dtype=np.float32) + player_vec[acting_player_id] = 1 + + # Encode Action + action_vec = np.zeros(37, dtype=np.float32) + if isinstance(action, DeclareTrumpAction): + trump_idx = DeclareTrumpAction.VALID_TRUMPS.index(action.trump_suit) + action_vec[32 + trump_idx] = 1 + elif isinstance(action, PlayCardAction): + suit_idx = BauernskatCard.suits.index(action.card.suit) + rank_idx = BauernskatCard.ranks.index(action.card.rank) + action_vec[suit_idx * 8 + rank_idx] = 1 + + # Encode Context + # Normalized pip scores + my_score = self.players[acting_player_id].score + opp_score = self.players[1 - acting_player_id].score + + # Trump suit + trump_vec = np.zeros(5, dtype=np.float32) + if self.trump_suit is not None: + trump_idx = DeclareTrumpAction.VALID_TRUMPS.index(self.trump_suit) + trump_vec[trump_idx] = 1 + + # Trick Leader + leader_vec = np.zeros(2, dtype=np.float32) + leader_vec[self.trick_leader_id] = 1 + + # Normalized count of tricks played + tricks_played_count = self.tricks_played + + context_vec = np.concatenate([ + [my_score / 120.0, opp_score / 120.0], + trump_vec, + leader_vec, + [tricks_played_count / 16.0] + ]) + + # Concatenated frame vector + final_frame = np.concatenate([player_vec, action_vec, context_vec]) + + assert final_frame.shape[0] == config.HISTORY_FRAME_SIZE + + return final_frame + + def declare_trump(self, action: DeclareTrumpAction) -> None: + """ + Processes a DeclareTrumpAction and records it in the history. + """ + + assert self.round_phase == 'declare_trump', "Trump can only be declared once." + assert self.current_player_id == 0, "Only Vorhand can declare trump." + + frame = self._create_history_frame(action) + self.history_frames.append(frame) + + self.trump_suit = action.trump_suit + self.dealer.deal_phase_two(self.players) + self.round_phase = 'play' + self.trick_leader_id = 0 + self.current_player_id = 0 + + def play_card(self, action: PlayCardAction) -> None: + """ + Processes a PlayCardAction and records it in the history. + """ + + assert self.round_phase == 'play', "Cannot play a card outside of the 'play' phase." + + card_to_play = action.card + player = self.players[self.current_player_id] + + column = player.find_column_for_card(card_to_play) + assert column is not None, f"{player} tried to play {card_to_play}, which is not in their layout." + + frame = self._create_history_frame(action) + self.history_frames.append(frame) + + column.play_card() + self.trick_moves.append((self.current_player_id, card_to_play)) + + if len(self.trick_moves) == 1: + self.current_player_id = 1 - self.current_player_id + else: + self._process_trick() + + def _get_card_strength(self, card: BauernskatCard, led_suit: str) -> int: + """ + Determines the strength of a card based on the current trump and started suit. + """ + + STRENGTH_MAP = {'7': 0, '8': 1, '9': 2, 'Q': 3, 'K': 4, '10': 5, 'A': 6, 'J': 7} + if self.trump_suit != 'G': + if card.rank == 'J': + jack_strength = {'C': 3, 'S': 2, 'H': 1, 'D': 0} + return 400 + jack_strength[card.suit] + if card.suit == self.trump_suit: + return 300 + STRENGTH_MAP[card.rank] + if card.suit == led_suit: + STRENGTH_MAP_SUIT = {'7': 0, '8': 1, '9': 2, 'J': 3, 'Q': 4, 'K': 5, '10': 6, 'A': 7} + return 200 + STRENGTH_MAP_SUIT[card.rank] + return 100 + STRENGTH_MAP[card.rank] + + def _determine_trick_winner(self) -> int: + """ + Determines the trick winner. 
+ """ + + assert len(self.trick_moves) == 2, "A trick must have exactly two cards to determine a winner." + + leader_id, leader_card = self.trick_moves[0] + follower_id, follower_card = self.trick_moves[1] + led_suit = leader_card.suit + leader_strength = self._get_card_strength(leader_card, led_suit) + follower_strength = self._get_card_strength(follower_card, led_suit) + + return leader_id if leader_strength > follower_strength else follower_id + + def _process_trick(self) -> None: + """ + Awards pips and updates state. + """ + + winner_id = self._determine_trick_winner() + trick_points = self.trick_moves[0][1].points + self.trick_moves[1][1].points + self.players[winner_id].add_points(trick_points) + self.played_cards.add(self.trick_moves[0][1]) + self.played_cards.add(self.trick_moves[1][1]) + self.trick_moves = [] + self.tricks_played += 1 + self.trick_leader_id = winner_id + self.current_player_id = winner_id + if self.is_over(): + self.round_phase = 'game_over' + + def is_over(self) -> bool: + """ + Checks if the round is over. + """ + + total_tricks = config.NUM_COLUMNS_PER_PLAYER * 2 + return self.tricks_played == total_tricks \ No newline at end of file diff --git a/rlcard/models/__init__.py b/rlcard/models/__init__.py index f772a17ac..ac87d41d5 100644 --- a/rlcard/models/__init__.py +++ b/rlcard/models/__init__.py @@ -29,3 +29,19 @@ register( model_id='gin-rummy-novice-rule', entry_point='rlcard.models.gin_rummy_rule_models:GinRummyNoviceRuleModel') + +register( + model_id='bauernskat-rule-random', + entry_point='rlcard.models.bauernskat_rule_models:BauernskatRandomRuleModelV1') + +register( + model_id='bauernskat-rule-frugal', + entry_point='rlcard.models.bauernskat_rule_models:BauernskatFrugalRuleModelV1') + +register( + model_id='bauernskat-rule-lookahead', + entry_point='rlcard.models.bauernskat_rule_models:BauernskatLookaheadRuleModelV1') + +register( + model_id='bauernskat-rule-shot-alphabeta', + entry_point='rlcard.models.bauernskat_rule_models:BauernskatSHOTAlphaBetaRuleModelV1') \ No newline at end of file diff --git a/rlcard/models/bauernskat_rule_models.py b/rlcard/models/bauernskat_rule_models.py new file mode 100644 index 000000000..3f3641a9d --- /dev/null +++ b/rlcard/models/bauernskat_rule_models.py @@ -0,0 +1,63 @@ +''' + File name: rlcard/models/bauernskat_rule_models.py + Author: Oliver Czerwinski + Date created: 08/15/2025 + Date last modified: 02/16/2026 + Python Version: 3.9+ +''' + +from rlcard.models.model import Model +from rlcard.agents.bauernskat.rule_agents import ( + BauernskatRandomRuleAgent, + BauernskatFrugalRuleAgent, + BauernskatLookaheadRuleAgent, + BauernskatSHOTAlphaBetaRuleAgent +) + +class BauernskatRandomRuleModelV1(Model): + """ + A model that uses the RandomRuleAgent for both players in Bauernskat. + """ + def __init__(self): + """Load rule agent""" + self.rule_agents = [BauernskatRandomRuleAgent() for _ in range(2)] + + @property + def agents(self): + """Get a list of agents for each position in a game.""" + return self.rule_agents + +class BauernskatFrugalRuleModelV1(Model): + """ + A model that uses the FrugalRuleAgent for both players in Bauernskat. 
+ """ + def __init__(self): + """Load rule agent""" + self.rule_agents = [BauernskatFrugalRuleAgent() for _ in range(2)] + + @property + def agents(self): + """Get a list of agents for each position in a game.""" + return self.rule_agents + +class BauernskatLookaheadRuleModelV1(Model): + """A model that uses the LookaheadRuleAgent for both players in Bauernskat.""" + def __init__(self): + """Load rule agent""" + self.rule_agents = [BauernskatLookaheadRuleAgent() for _ in range(2)] + + @property + def agents(self): + """Get a list of agents for each position in a game.""" + return self.rule_agents + +class BauernskatSHOTAlphaBetaRuleModelV1(Model): + """A model that uses the SHOT+AlphaBeta agent for both players in Bauernskat.""" + def __init__(self): + """Load rule agent""" + self.rule_agents = [BauernskatSHOTAlphaBetaRuleAgent() for _ in range(2)] + + @property + def agents(self): + """Get a list of agents for each position in a game.""" + return self.rule_agents \ No newline at end of file diff --git a/tests/envs/test_bauernskat.py b/tests/envs/test_bauernskat.py new file mode 100644 index 000000000..b35415835 --- /dev/null +++ b/tests/envs/test_bauernskat.py @@ -0,0 +1,284 @@ +''' + File name: tests/envs/test_bauernskat_env.py + Author: Oliver Czerwinski + Date created: 08/12/2025 + Date last modified: 12/25/2025 + Python Version: 3.9+ +''' +import unittest +import numpy as np +import random + +import rlcard +from rlcard.agents.bauernskat.rule_agents import BauernskatRandomRuleAgent +from rlcard.games.bauernskat.action_event import DeclareTrumpAction, PlayCardAction +from rlcard.games.bauernskat.card import BauernskatCard +from rlcard.games.bauernskat.player import _CardColumn +from rlcard.games.bauernskat import config + + +class TestBauernskatEnv(unittest.TestCase): + """ + Tests the BauernskatEnv. + """ + + def test_init_and_extract_state(self): + """ + Tests initialization and the structure of the state. + """ + + env = rlcard.make('bauernskat') + state, player_id = env.reset() + + self.assertEqual(player_id, 0) + self.assertIn('obs', state) + self.assertIn('legal_actions', state) + + obs = state['obs'] + + expected_keys = [ + 'my_layout_tensor', 'opponent_layout_tensor', 'unaccounted_cards_mask', + 'trick_card_ids', 'cemetery_card_ids', 'my_hidden_indicators', + 'opponent_hidden_indicators', 'context', 'action_history' + ] + for key in expected_keys: + self.assertIn(key, obs, f"Expected key '{key}' not found in observation.") + + self.assertEqual(obs['my_layout_tensor'].shape, (config.NUM_COLUMNS_PER_PLAYER, 2)) + self.assertEqual(obs['my_layout_tensor'].dtype, np.int32) + self.assertEqual(obs['opponent_layout_tensor'].shape, (config.NUM_COLUMNS_PER_PLAYER, 2)) + self.assertEqual(obs['opponent_layout_tensor'].dtype, np.int32) + self.assertEqual(obs['unaccounted_cards_mask'].shape, (32,)) + self.assertEqual(obs['unaccounted_cards_mask'].dtype, np.float32) + + self.assertEqual(obs['context'].shape, (11,)) + self.assertEqual(obs['my_hidden_indicators'].shape, (config.NUM_COLUMNS_PER_PLAYER,)) + self.assertEqual(obs['action_history'].shape, (config.HISTORY_SEQUENCE_LENGTH, config.HISTORY_FRAME_SIZE)) + self.assertEqual(obs['context'].dtype, np.float32) + + def test_decode_action(self): + """ + Tests that action IDs are decoded into game actions. 
+ """ + + env = rlcard.make('bauernskat') + + decoded_declare = env._decode_action(2) + self.assertIsInstance(decoded_declare, DeclareTrumpAction) + self.assertEqual(decoded_declare.trump_suit, 'H') + + decoded_play = env._decode_action(5) + self.assertIsInstance(decoded_play, PlayCardAction) + self.assertEqual(decoded_play.card.card_id, 0) + + def test_get_legal_actions(self): + """ + Tests the correct set of legal actions. + """ + + env = rlcard.make('bauernskat') + env.reset() + legal_actions = env._get_legal_actions() + + self.assertEqual(len(legal_actions), 5) + self.assertListEqual(list(legal_actions.keys()), [0, 1, 2, 3, 4]) + + def test_get_payoffs_and_scores(self): + """ + Tests the payoff and score after a random game. + """ + + env = rlcard.make('bauernskat') + env.reset() + + while not env.is_over(): + legal_actions = list(env.get_state(env.get_player_id())['legal_actions'].keys()) + action = random.choice(legal_actions) + env.step(action) + + payoffs = env.get_payoffs() + self.assertIsInstance(payoffs, np.ndarray) + self.assertEqual(len(payoffs), 2) + self.assertIsInstance(payoffs[0], np.float32) + self.assertAlmostEqual(payoffs[0], -payoffs[1]) + + def test_run_with_random_agent(self): + """ + Tests a full game with a random agent and the structure of the payoffs. + """ + + env = rlcard.make('bauernskat') + + env.set_agents([ + BauernskatRandomRuleAgent(), + BauernskatRandomRuleAgent(), + ]) + + trajectories, payoffs = env.run(is_training=False) + self.assertEqual(len(trajectories), 2) + self.assertIsInstance(payoffs, np.ndarray) + self.assertEqual(len(payoffs), 2) + + def test_deterministic_run_with_seeded_custom_agent(self): + """ + Tests a full game with a seedable custom rule agent. + """ + + env = rlcard.make('bauernskat') + env.seed(21000) + + agent1 = BauernskatRandomRuleAgent(seed=21001) + agent2 = BauernskatRandomRuleAgent(seed=21002) + + env.set_agents([agent1, agent2]) + + _, payoffs = env.run(is_training=False) + + expected_payoffs = np.array([-100.0, 100.0], dtype=np.float32) + np.testing.assert_array_equal(payoffs, expected_payoffs) + + def test_layout_tensor_mapping(self): + """ + Tests the mapping of the game state layout tensors. + """ + + # Perfect information + env = rlcard.make('bauernskat', config={'information_level': 'perfect'}) + env.reset() + + env.game.players[0].layout = [_CardColumn() for _ in range(config.NUM_COLUMNS_PER_PLAYER)] + + ace_spades = BauernskatCard('S', 'A') + king_clubs = BauernskatCard('C', 'K') + + env.game.players[0].layout[0].open_card = ace_spades + env.game.players[0].layout[3].open_card = None + env.game.players[0].layout[5].open_card = king_clubs + env.game.players[0].layout[5].closed_card = ace_spades + + state = env.get_state(player_id=0) + layout_tensor = state['obs']['my_layout_tensor'] + + self.assertEqual(layout_tensor[0, 0], ace_spades.card_id) + self.assertEqual(layout_tensor[0, 1], 32) + + self.assertEqual(layout_tensor[3, 0], 32) + self.assertEqual(layout_tensor[3, 1], 32) + + self.assertEqual(layout_tensor[5, 0], king_clubs.card_id) + self.assertEqual(layout_tensor[5, 1], ace_spades.card_id) + + def test_perfect_information_mode(self): + """ + Tests the effect of the information level. 
+ """ + + # Normal + env_normal = rlcard.make('bauernskat', config={'seed': 42}) + state_normal, _ = env_normal.reset() + obs_normal = state_normal['obs'] + + self.assertTrue(np.all(obs_normal['opponent_layout_tensor'][:, 1] == 32)) + self.assertEqual(np.sum(obs_normal['unaccounted_cards_mask']), 28.0) + + # Perfect + env_perfect = rlcard.make('bauernskat', config={'seed': 42, 'information_level': 'perfect'}) + state_perfect, _ = env_perfect.reset() + obs_perfect = state_perfect['obs'] + + self.assertTrue(np.any(obs_perfect['opponent_layout_tensor'][:, 1] != 32)) + self.assertEqual(np.sum(obs_perfect['unaccounted_cards_mask']), 20.0) + + def test_extract_state_maps_full_game_state_correctly(self): + """ + Tests a complex game state extraction. + """ + env = rlcard.make('bauernskat', config={'information_level': 'perfect'}) + env.reset() + game = env.game + player0, player1 = game.players + + game.round.round_phase = 'play' + game.round.trump_suit = 'H' + game.round.current_player_id = 1 + game.round.trick_leader_id = 0 + game.round.tricks_played = 5 + player0.score = 35 + player1.score = 25 + + p0_open = BauernskatCard('S', 'A') + p0_hidden = BauernskatCard('C', '7') + p1_open = BauernskatCard('C', 'K') + p1_hidden = BauernskatCard('H', '8') + trick_card = BauernskatCard('D', '10') + cemetery_card = BauernskatCard('D', '7') + + player0.layout = [_CardColumn() for _ in range(8)] + player1.layout = [_CardColumn() for _ in range(8)] + player0.layout[0].open_card = p0_open + player0.layout[2].closed_card = p0_hidden + player1.layout[1].open_card = p1_open + player1.layout[4].closed_card = p1_hidden + game.round.trick_moves = [(0, trick_card)] + game.round.played_cards = {cemetery_card} + + state = env.get_state(player_id=1) + obs = state['obs'] + + self.assertEqual(obs['my_layout_tensor'][1, 0], p1_open.card_id) + self.assertEqual(obs['my_layout_tensor'][4, 1], p1_hidden.card_id) + self.assertEqual(obs['opponent_layout_tensor'][0, 0], p0_open.card_id) + self.assertEqual(obs['opponent_layout_tensor'][2, 1], p0_hidden.card_id) + + known_cards = {p0_open.card_id, p0_hidden.card_id, p1_open.card_id, + p1_hidden.card_id, trick_card.card_id, cemetery_card.card_id} + expected_mask = np.ones(32, dtype=np.float32) + for card_id in known_cards: + expected_mask[card_id] = 0.0 + np.testing.assert_array_equal(obs['unaccounted_cards_mask'], expected_mask) + self.assertEqual(np.sum(obs['unaccounted_cards_mask']), 26.0) + + expected_context = np.array([ + 0., 0., 1., 0., 0., 1.0, 1.0, 0.0, + 25.0 / 480.0, 35.0 / 480.0, 5.0 / 16.0 + ], dtype=np.float32) + np.testing.assert_array_almost_equal(obs['context'], expected_context) + + def test_get_scores(self): + """ + Tests that get_scores() returns the current pips. + """ + + env = rlcard.make('bauernskat') + env.reset() + + env.game.players[0].score = 42 + env.game.players[1].score = 18 + + scores = env.get_scores() + + self.assertIsInstance(scores, np.ndarray) + self.assertEqual(scores.dtype, np.int32) + np.testing.assert_array_equal(scores, [42, 18]) + + def test_asymmetric_information_tensors(self): + """ + Tests asymmetric information levels between players. 
+ """ + + # "perfect" "normal" + env = rlcard.make('bauernskat', config={'information_level': {0: 'perfect', 1: 'normal'}}) + env.reset() + + env.game.players[1].layout[0].closed_card = BauernskatCard('H', 'K') + + state_p0 = env.get_state(0) + opp_tensor_p0 = state_p0['obs']['opponent_layout_tensor'] + self.assertTrue(np.any(opp_tensor_p0[:, 1] != 32)) + + state_p1 = env.get_state(1) + opp_tensor_p1 = state_p1['obs']['opponent_layout_tensor'] + self.assertTrue(np.all(opp_tensor_p1[:, 1] == 32)) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/games/test_bauernskat_game.py b/tests/games/test_bauernskat_game.py new file mode 100644 index 000000000..03e59b7f4 --- /dev/null +++ b/tests/games/test_bauernskat_game.py @@ -0,0 +1,359 @@ +''' + File name: tests/games/test_bauernskat_card.py + Author: Oliver Czerwinski + Date created: 07/29/2025 + Date last modified: 12/25/2025 + Python Version: 3.9+ +''' + +import unittest +import numpy as np + +from rlcard.games.bauernskat import config + +from rlcard.games.bauernskat.card import BauernskatCard +from rlcard.games.bauernskat.action_event import ActionEvent, DeclareTrumpAction, PlayCardAction +from rlcard.games.bauernskat.player import BauernskatPlayer, _CardColumn +from rlcard.games.bauernskat.dealer import BauernskatDealer +from rlcard.games.bauernskat.round import BauernskatRound +from rlcard.games.bauernskat.game import BauernskatGame + +class TestBauernskatGame(unittest.TestCase): + """ + Tests for BauernskatGame. + """ + + def setUp(self): + """ + New BauernskatGame for each test. + """ + self.game = BauernskatGame() + + def test_game_init(self): + """ + Tests the state after initializing the game. + """ + + state, player_id = self.game.init_game() + + self.assertEqual(player_id, 0) + self.assertIn('raw_state_info', state) + self.assertIn('legal_actions', state) + self.assertGreater(len(state['raw_state_info']['my_cards']), 0) + self.assertFalse(self.game.is_over()) + + def test_game_step_and_transitions(self): + """ + Tests the step function and the resulting state transitions. + """ + + state, _ = self.game.init_game() + + declare_action_id = list(state['legal_actions'].keys())[0] + action_event = ActionEvent.from_action_id(declare_action_id) + _, next_player_id = self.game.step(action_event) + + self.assertEqual(self.game.round.round_phase, 'play') + self.assertEqual(next_player_id, 0) + + playable_cards = self.game.players[0].get_playable_cards() + self.assertGreater(len(playable_cards), 0, "Player 0 should have playable cards after trump declaration.") + card_to_play = playable_cards[0] + play_action_object = PlayCardAction(card_to_play) + + _, final_player_id = self.game.step(play_action_object) + + self.assertEqual(final_player_id, 1) + self.assertEqual(len(self.game.round.trick_moves), 1) + + def _create_deterministic_deck(self, p0_jacks, p1_jacks): + """ + Helper function to create a deck with specifically placed Jacks. 
+ """ + + deck = [None] * 32 + all_cards = set(BauernskatCard.get_deck()) + + p0_deal_slots = [31, 30, 29, 28, 23, 22, 21, 20, 15, 14, 13, 12, 7, 6, 5, 4] + p1_deal_slots = [27, 26, 25, 24, 19, 18, 17, 16, 11, 10, 9, 8, 3, 2, 1, 0] + + for jack in p0_jacks: + deck[p0_deal_slots.pop()] = jack + all_cards.remove(jack) + + for jack in p1_jacks: + deck[p1_deal_slots.pop()] = jack + all_cards.remove(jack) + + for i in range(32): + if deck[i] is None: + deck[i] = all_cards.pop() + + return deck + + def test_get_payoffs_with_matadors(self): + """ + Tests the original Skat scoring with with different amount of Matadors. + """ + + self.game.init_game() + + # With 1 + jc, js, jh, jd = [BauernskatCard(s, 'J') for s in ('C', 'S', 'H', 'D')] + fixed_deck = self._create_deterministic_deck(p0_jacks=[jc], p1_jacks=[js, jh, jd]) + + self.game.round.dealer.shuffled_deck = fixed_deck + self.game.round.trump_suit = 'H' # Base value = 10 + self.game.round.tricks_played = config.NUM_COLUMNS_PER_PLAYER * 2 + + # With 1 + self.game.players[0].score = 61 + self.assertEqual(self.game.get_payoffs(), [20.0, -20.0]) + + # With 2 + self.game.players[0].score = 60 + self.assertEqual(self.game.get_payoffs(), [-40.0, 40.0]) + + def test_get_payoffs_without_matadors(self): + """ + Tests the original Skat scoring with without x amount of Matadors. + """ + + self.game.init_game() + + # Without 2 + jc, js, jh, jd = [BauernskatCard(s, 'J') for s in ('C', 'S', 'H', 'D')] + fixed_deck = self._create_deterministic_deck(p0_jacks=[jh, jd], p1_jacks=[jc, js]) + + self.game.round.dealer.shuffled_deck = fixed_deck + self.game.round.trump_suit = 'S' + self.game.round.tricks_played = config.NUM_COLUMNS_PER_PLAYER * 2 + + # Without 2 + self.game.players[0].score = 70 + self.assertEqual(self.game.get_payoffs(), [33.0, -33.0]) + + def test_get_payoffs_grand_game(self): + """ + Tests the original Skat scoring for a Grand game. + """ + + self.game.init_game() + + # With 2 + jc, js, jh, jd = [BauernskatCard(s, 'J') for s in ('C', 'S', 'H', 'D')] + fixed_deck = self._create_deterministic_deck(p0_jacks=[jc, js], p1_jacks=[jh, jd]) + + self.game.round.dealer.shuffled_deck = fixed_deck + self.game.round.trump_suit = 'G' + self.game.round.tricks_played = config.NUM_COLUMNS_PER_PLAYER * 2 + + # With 2 + self.game.players[0].score = 80 + self.assertEqual(self.game.get_payoffs(), [72.0, -72.0]) + + def test_get_payoffs_with_schneider_bonus(self): + """ + Tests the Schneider multiplier. + """ + + self.game.init_game() + + jc, js, jh, jd = [BauernskatCard(s, 'J') for s in ('C', 'S', 'H', 'D')] + fixed_deck = self._create_deterministic_deck(p0_jacks=[jc], p1_jacks=[js, jh, jd]) + + self.game.round.dealer.shuffled_deck = fixed_deck + self.game.round.trump_suit = 'H' + self.game.round.tricks_played = config.NUM_COLUMNS_PER_PLAYER * 2 + + self.game.players[0].score = 91 + + self.assertEqual(self.game.get_payoffs(), [40.0, -40.0]) + + def test_get_payoffs_with_schwarz_bonus(self): + """ + Tests the Schwarz multiplier. 
+ """ + + self.game.init_game() + + jc, js, jh, jd = [BauernskatCard(s, 'J') for s in ('C', 'S', 'H', 'D')] + fixed_deck = self._create_deterministic_deck(p0_jacks=[jh, jd], p1_jacks=[jc, js]) + + self.game.round.dealer.shuffled_deck = fixed_deck + self.game.round.trump_suit = 'S' + self.game.round.tricks_played = config.NUM_COLUMNS_PER_PLAYER * 2 + + self.game.players[0].score = 120 + + self.assertEqual(self.game.get_payoffs(), [132.0, -132.0]) + + def test_get_payoffs_no_schneider_bonus_for_losing_declarer(self): + """ + Tests for no Schneider or Schwarz multiplier when the declarer loses. + """ + + self.game.init_game() + + jc, js, jh, jd = [BauernskatCard(s, 'J') for s in ('C', 'S', 'H', 'D')] + fixed_deck = self._create_deterministic_deck(p0_jacks=[jc], p1_jacks=[js, jh, jd]) + + self.game.round.dealer.shuffled_deck = fixed_deck + self.game.round.trump_suit = 'H' + self.game.round.tricks_played = config.NUM_COLUMNS_PER_PLAYER * 2 + + self.game.players[0].score = 29 + + self.assertEqual(self.game.get_payoffs(), [-40.0, 40.0]) + + def test_get_payoffs_60_60_tie_is_a_loss_for_declarer(self): + """ + Tests the tie rule where it ends in a loss for the declarer. + """ + + self.game.init_game() + + jc, js, jh, jd = [BauernskatCard(s, 'J') for s in ('C', 'S', 'H', 'D')] + + fixed_deck = self._create_deterministic_deck(p0_jacks=[jc], p1_jacks=[js, jh, jd]) + + self.game.round.dealer.shuffled_deck = fixed_deck + self.game.round.trump_suit = 'H' + self.game.round.tricks_played = config.NUM_COLUMNS_PER_PLAYER * 2 + + self.game.players[0].score = 60 + + self.assertEqual(self.game.get_payoffs(), [-40.0, 40.0]) + + def test_get_payoffs_with_all_matadors(self): + """ + Tests the scoring for the edge case of having all four Jacks. + """ + + self.game.init_game() + + # With 4 + jacks = [BauernskatCard(s, 'J') for s in ('C', 'S', 'H', 'D')] + fixed_deck = self._create_deterministic_deck(p0_jacks=jacks, p1_jacks=[]) + + self.game.round.dealer.shuffled_deck = fixed_deck + self.game.round.trump_suit = 'D' + self.game.round.tricks_played = config.NUM_COLUMNS_PER_PLAYER * 2 + + self.game.players[0].score = 75 + + self.assertEqual(self.game.get_payoffs(), [45.0, -45.0]) + + def test_get_payoffs_without_any_matadors(self): + """ + Tests the scoring for the edge case of having none of the four Jacks. + """ + + self.game.init_game() + + # Without 4 + jacks = [BauernskatCard(s, 'J') for s in ('C', 'S', 'H', 'D')] + fixed_deck = self._create_deterministic_deck(p0_jacks=[], p1_jacks=jacks) + + self.game.round.dealer.shuffled_deck = fixed_deck + self.game.round.trump_suit = 'C' + self.game.round.tricks_played = config.NUM_COLUMNS_PER_PLAYER * 2 + + self.game.players[0].score = 65 + + self.assertEqual(self.game.get_payoffs(), [60.0, -60.0]) + + def test_final_state_info(self): + """ + Tests that pip difference and game value are added to the state when the game ends. 
+ """ + + self.game.init_game() + jc, js, jh, jd = [BauernskatCard(s, 'J') for s in ('C', 'S', 'H', 'D')] + fixed_deck = self._create_deterministic_deck(p0_jacks=[jc], p1_jacks=[js, jh, jd]) + self.game.round.dealer.shuffled_deck = fixed_deck + self.game.round.trump_suit = 'H' + self.game.round.tricks_played = config.NUM_COLUMNS_PER_PLAYER * 2 + self.game.players[0].score = 61 + + final_state_p0 = self.game.get_state(0) + final_state_p1 = self.game.get_state(1) + + self.assertIn('pip_difference', final_state_p0['raw_state_info']) + self.assertIn('game_value_payoff', final_state_p0['raw_state_info']) + + self.assertEqual(final_state_p0['raw_state_info']['pip_difference'], 2) # 61-59 + self.assertEqual(final_state_p0['raw_state_info']['game_value_payoff'], 20.0) + + self.assertEqual(final_state_p1['raw_state_info']['pip_difference'], -2) # 59-61 + self.assertEqual(final_state_p1['raw_state_info']['game_value_payoff'], -20.0) + + def test_get_state_information_levels(self): + """ + Tests information levels based on the game config. + """ + + # "normal" "normal" + game_normal = BauernskatGame(information_level='normal') + game_normal.init_game() + raw_info_normal_p0 = game_normal.get_state(0)['raw_state_info'] + raw_info_normal_p1 = game_normal.get_state(1)['raw_state_info'] + + self.assertEqual(raw_info_normal_p0['my_hidden_cards'], []) + self.assertEqual(raw_info_normal_p0['opponent_hidden_cards'], []) + self.assertEqual(raw_info_normal_p1['my_hidden_cards'], []) + self.assertEqual(raw_info_normal_p1['opponent_hidden_cards'], []) + + # "show_self" "show_self" + game_show_self = BauernskatGame(information_level='show_self') + game_show_self.init_game() + raw_info_self_p0 = game_show_self.get_state(0)['raw_state_info'] + raw_info_self_p1 = game_show_self.get_state(1)['raw_state_info'] + + self.assertGreater(len(raw_info_self_p0['my_hidden_cards']), 0) + self.assertEqual(len(raw_info_self_p0['opponent_hidden_cards']), 0) + + self.assertGreater(len(raw_info_self_p1['my_hidden_cards']), 0) + self.assertEqual(len(raw_info_self_p1['opponent_hidden_cards']), 0) + + # "perfect" "perfect" + game_perfect = BauernskatGame(information_level='perfect') + game_perfect.init_game() + raw_info_perfect_p0 = game_perfect.get_state(0)['raw_state_info'] + raw_info_perfect_p1 = game_perfect.get_state(1)['raw_state_info'] + + self.assertGreater(len(raw_info_perfect_p0['my_hidden_cards']), 0) + self.assertGreater(len(raw_info_perfect_p0['opponent_hidden_cards']), 0) + self.assertGreater(len(raw_info_perfect_p1['my_hidden_cards']), 0) + self.assertGreater(len(raw_info_perfect_p1['opponent_hidden_cards']), 0) + + self.assertEqual(raw_info_perfect_p0['opponent_hidden_cards'], raw_info_perfect_p1['my_hidden_cards']) + + # "perfect" "normal" + game_mixed = BauernskatGame(information_level={0: 'perfect', 1: 'normal'}) + game_mixed.init_game() + + raw_info_mixed_p0 = game_mixed.get_state(0)['raw_state_info'] + raw_info_mixed_p1 = game_mixed.get_state(1)['raw_state_info'] + + self.assertGreater(len(raw_info_mixed_p0['my_hidden_cards']), 0) + self.assertGreater(len(raw_info_mixed_p0['opponent_hidden_cards']), 0) + + self.assertEqual(raw_info_mixed_p1['my_hidden_cards'], []) + self.assertEqual(raw_info_mixed_p1['opponent_hidden_cards'], []) + + def test_is_over(self): + """ + Tests the is_over method. 
+ """ + + self.game.init_game() + self.assertFalse(self.game.is_over()) + + total_tricks = config.NUM_COLUMNS_PER_PLAYER * 2 + self.game.round.tricks_played = total_tricks + self.assertTrue(self.game.is_over()) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file