|
1 | 1 | from typing import Optional |
2 | 2 | from quoridor_env import env |
3 | 3 | from agents import Agent |
| 4 | +from agents import AgentRegistry |
| 5 | +from agents.replay import ReplayAgent |
4 | 6 | from dataclasses import dataclass |
5 | 7 | import time |
6 | 8 |
|
@@ -73,14 +75,18 @@ def __init__( |
73 | 75 | step_rewards: bool = False, |
74 | 76 | renderer: Optional[ArenaPlugin] = None, |
75 | 77 | saver: Optional[ArenaPlugin] = None, |
| 78 | + plugins: list[ArenaPlugin] = [], |
76 | 79 | ): |
77 | 80 | self.board_size = board_size |
78 | 81 | self.max_walls = max_walls |
79 | 82 | self.step_rewards = step_rewards |
80 | | - self.game = env(board_size=board_size, max_walls=max_walls, step_rewards=step_rewards) |
| 83 | + self.game = env( |
| 84 | + board_size=board_size, max_walls=max_walls, step_rewards=step_rewards |
| 85 | + ) |
81 | 86 |
|
82 | | - plugins = [p for p in [renderer, saver] if p is not None] |
83 | | - self.plugins = CompositeArenaPlugin(plugins) |
| 87 | + self.plugins = CompositeArenaPlugin( |
| 88 | + [p for p in plugins + [renderer, saver] if p is not None] |
| 89 | + ) |
84 | 90 |
|
85 | 91 | def _play_game(self, agent1: Agent, agent2: Agent, game_id: str) -> GameResult: |
86 | 92 | self.game.reset() |
@@ -128,12 +134,40 @@ def play_games(self, players: list[str], times: int): |
128 | 134 | for i in range(len(players)): |
129 | 135 | for j in range(i + 1, len(players)): |
130 | 136 | for t in range(times): |
131 | | - agent_i = Agent.create(players[i]) |
132 | | - agent_j = Agent.create(players[j]) |
133 | | - agent_1, agent_2 = (agent_i, agent_j) if t % 2 == 0 else (agent_j, agent_i) |
| 137 | + agent_i = AgentRegistry.create(players[i]) |
| 138 | + agent_j = AgentRegistry.create(players[j]) |
| 139 | + agent_1, agent_2 = ( |
| 140 | + (agent_i, agent_j) if t % 2 == 0 else (agent_j, agent_i) |
| 141 | + ) |
134 | 142 |
|
135 | 143 | result = self._play_game(agent_1, agent_2, f"game_{match_id:04d}") |
136 | 144 | results.append(result) |
137 | 145 | match_id += 1 |
138 | 146 |
|
139 | 147 | self.plugins.end_arena(self.game, results) |
| 148 | + |
def replay_games(
    self,
    arena_data: dict,
    game_ids_to_replay: Optional[list[str]] = None,
):
    """Replay previously recorded games through the arena.

    Every selected game is run through ``_play_game`` with two
    ``ReplayAgent`` instances, each built from the recorded player entry
    and that player's share of the recorded actions. The whole run is
    bracketed by the plugins' ``start_arena`` / ``end_arena`` calls.

    Args:
        arena_data: Recorded arena data. Read as
            ``arena_data["games"][game_id]``; each game entry must provide
            the keys ``"player1"``, ``"player2"`` and ``"actions"``.
        game_ids_to_replay: IDs of the games to replay, in the given order.
            When empty or ``None``, every game in ``arena_data["games"]``
            is replayed.

    Raises:
        KeyError: If ``arena_data`` has no ``"games"`` entry, or a requested
            game ID is not present in it.
    """
    self.plugins.start_arena(self.game)

    recorded_games = arena_data["games"]
    # An empty (or omitted) selection means "replay everything".
    if game_ids_to_replay:
        game_ids = list(game_ids_to_replay)
    else:
        game_ids = list(recorded_games)

    results = []
    for game_id in game_ids:
        game_data = recorded_games[game_id]
        actions = game_data["actions"]

        # Even-indexed actions go to player1, odd-indexed to player2.
        # NOTE(review): assumes recorded actions strictly alternate and that
        # player1 moved first -- confirm against the code that records them.
        agent_1 = ReplayAgent(game_data["player1"], actions[::2])
        agent_2 = ReplayAgent(game_data["player2"], actions[1::2])

        results.append(self._play_game(agent_1, agent_2, game_id))

    self.plugins.end_arena(self.game, results)
0 commit comments