Skip to content

Commit 481bbdb

Browse files
authored
Adding replay tool to allow replaying matches from saved yaml files (#114)
* Adding replay agent and replay tool to allow replaying matches from saved YAML files
* Adding delay between moves
* Decoupling agent registry from base class, keeping self registration for now
1 parent 9052f30 commit 481bbdb

File tree

9 files changed

+3415
-36
lines changed

9 files changed

+3415
-36
lines changed

deep_quoridor/src/agents/__init__.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,45 @@ class Agent:
44
Given a game state, the agent should return an action.
55
"""
66

7-
agents = {}
7+
def name(self) -> str:
8+
raise NotImplementedError("You must implement the name method")
89

9-
def __init_subclass__(cls, **kwargs):
10-
friendly_name = Agent._friendly_name(cls.__name__)
11-
Agent.agents[friendly_name] = cls
10+
def get_action(self, game) -> int:
11+
raise NotImplementedError("You must implement the get_action method")
1212

13-
def name(self):
14-
return Agent._friendly_name(self.__class__.__name__)
1513

16-
@staticmethod
17-
def _friendly_name(class_name: str):
18-
return class_name.replace("Agent", "").lower()
14+
class AgentRegistry:
15+
agents = {}
1916

2017
@staticmethod
21-
def create(friendly_name: str) -> "Agent":
22-
return Agent.agents[friendly_name]()
18+
def create(friendly_name: str) -> Agent:
19+
return AgentRegistry.agents[friendly_name]()
2320

2421
@staticmethod
2522
def names():
26-
return list(Agent.agents.keys())
23+
return list(AgentRegistry.agents.keys())
2724

28-
def get_action(self, game):
29-
raise NotImplementedError("You must implement the get_action method")
25+
@staticmethod
26+
def register(name: str, agent_class):
27+
AgentRegistry.agents[name] = agent_class
28+
29+
30+
class SelfRegisteringAgent(Agent):
    """Agent base class whose subclasses register themselves automatically.

    Defining a subclass adds it to ``AgentRegistry`` under a friendly name
    derived from the class name: every "Agent" substring is removed and the
    remainder is lowercased (e.g. ``RandomAgent`` -> ``random``).
    """

    def __init_subclass__(cls, **kwargs):
        # Keep the hook cooperative so that other base classes or mixins
        # defining __init_subclass__ still run and receive their keyword args.
        super().__init_subclass__(**kwargs)
        AgentRegistry.register(SelfRegisteringAgent._friendly_name(cls.__name__), cls)

    def name(self) -> str:
        """Return the friendly name derived from this agent's class name."""
        return SelfRegisteringAgent._friendly_name(self.__class__.__name__)

    @staticmethod
    def _friendly_name(class_name: str) -> str:
        """Derive the registry name: drop every "Agent" and lowercase the rest."""
        return class_name.replace("Agent", "").lower()
3045

3146

32-
from agents.random import RandomAgent
33-
from agents.simple import SimpleAgent
47+
from agents.random import RandomAgent # noqa: E402, F401
48+
from agents.simple import SimpleAgent # noqa: E402, F401

deep_quoridor/src/agents/random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from agents import Agent
1+
from agents import SelfRegisteringAgent
22

33

4-
class RandomAgent(Agent):
4+
class RandomAgent(SelfRegisteringAgent):
55
def __init__(self):
66
super().__init__()
77

deep_quoridor/src/agents/replay.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from agents import Agent
2+
3+
4+
class ReplayAgent(Agent):
    """Agent that replays a fixed, pre-recorded sequence of actions.

    Used to replay a sequence of actions, typically for testing or
    demonstration purposes: every call to ``get_action`` returns the next
    action from the predefined list, in order.

    Args:
        name (str): Name of the agent whose actions are being replayed.
        predefined_actions (list[int]): Actions to be played in sequence.

    Attributes:
        actions (list[int]): The list of predefined actions.
        action_index (int): Index of the next action to return.
        original_name (str): The name passed at construction time.
    """

    def __init__(self, name: str, predefined_actions: list[int]):
        super().__init__()
        self.actions = predefined_actions
        self.action_index = 0
        self.original_name = name

    def get_action(self, game) -> int:
        """Return the next predefined action; ``game`` is not consulted.

        Raises:
            IndexError: If every predefined action has already been returned.
        """
        if self.action_index >= len(self.actions):
            # Same exception type as plain list indexing would raise, but with
            # enough context to tell which replayed agent ran out of moves.
            raise IndexError(
                f"{self.name()} has no recorded action left "
                f"(all {len(self.actions)} actions were already played)"
            )
        action = self.actions[self.action_index]
        self.action_index += 1
        return action

    def name(self) -> str:
        """Return the original name prefixed with ``replay-``."""
        return f"replay-{self.original_name}"

deep_quoridor/src/agents/simple.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from agents import Agent
1+
from agents import SelfRegisteringAgent
22

33

44
def sample_random_action_sequence(game, max_path_length):
@@ -29,7 +29,7 @@ def sample_random_action_sequence(game, max_path_length):
2929
return action_sequence, total_reward
3030

3131

32-
class SimpleAgent(Agent):
32+
class SimpleAgent(SelfRegisteringAgent):
3333
def __init__(self, sequence_length=3, num_sequences=10):
3434
super().__init__()
3535
self.sequence_length = sequence_length
@@ -42,7 +42,9 @@ def get_action(self, game):
4242

4343
possible_action_sequences = []
4444
for _ in range(self.num_sequences):
45-
action_sequence, total_reward = sample_random_action_sequence(game.copy(), self.sequence_length)
45+
action_sequence, total_reward = sample_random_action_sequence(
46+
game.copy(), self.sequence_length
47+
)
4648
possible_action_sequences.append((action_sequence, total_reward))
4749

4850
# Choose the action sequence with the highest reward.

deep_quoridor/src/arena.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Optional
22
from quoridor_env import env
33
from agents import Agent
4+
from agents import AgentRegistry
5+
from agents.replay import ReplayAgent
46
from dataclasses import dataclass
57
import time
68

@@ -73,14 +75,18 @@ def __init__(
7375
step_rewards: bool = False,
7476
renderer: Optional[ArenaPlugin] = None,
7577
saver: Optional[ArenaPlugin] = None,
78+
plugins: list[ArenaPlugin] = [],
7679
):
7780
self.board_size = board_size
7881
self.max_walls = max_walls
7982
self.step_rewards = step_rewards
80-
self.game = env(board_size=board_size, max_walls=max_walls, step_rewards=step_rewards)
83+
self.game = env(
84+
board_size=board_size, max_walls=max_walls, step_rewards=step_rewards
85+
)
8186

82-
plugins = [p for p in [renderer, saver] if p is not None]
83-
self.plugins = CompositeArenaPlugin(plugins)
87+
self.plugins = CompositeArenaPlugin(
88+
[p for p in plugins + [renderer, saver] if p is not None]
89+
)
8490

8591
def _play_game(self, agent1: Agent, agent2: Agent, game_id: str) -> GameResult:
8692
self.game.reset()
@@ -128,12 +134,40 @@ def play_games(self, players: list[str], times: int):
128134
for i in range(len(players)):
129135
for j in range(i + 1, len(players)):
130136
for t in range(times):
131-
agent_i = Agent.create(players[i])
132-
agent_j = Agent.create(players[j])
133-
agent_1, agent_2 = (agent_i, agent_j) if t % 2 == 0 else (agent_j, agent_i)
137+
agent_i = AgentRegistry.create(players[i])
138+
agent_j = AgentRegistry.create(players[j])
139+
agent_1, agent_2 = (
140+
(agent_i, agent_j) if t % 2 == 0 else (agent_j, agent_i)
141+
)
134142

135143
result = self._play_game(agent_1, agent_2, f"game_{match_id:04d}")
136144
results.append(result)
137145
match_id += 1
138146

139147
self.plugins.end_arena(self.game, results)
148+
149+
def replay_games(self, arena_data: dict, game_ids_to_replay: list[str]):
    """Replay a series of games from previously recorded arena data.

    Each selected game is played again by two ``ReplayAgent`` instances that
    return the recorded actions in order. The arena plugins receive
    ``start_arena`` before the first game and ``end_arena`` with the collected
    results after the last one.

    Args:
        arena_data: Recorded arena data. It must contain a "games" mapping
            from game id to a record with "player1", "player2" and "actions".
        game_ids_to_replay: Ids of the games to replay. When empty (or None),
            every recorded game is replayed.

    Raises:
        KeyError: If a requested game id is not present in the recorded data.
    """
    recorded_games = arena_data["games"]

    # An empty selection means "replay everything".
    if game_ids_to_replay:
        game_ids = list(game_ids_to_replay)
    else:
        game_ids = list(recorded_games)

    # Fail before any plugin is started rather than halfway through a replay.
    missing = [game_id for game_id in game_ids if game_id not in recorded_games]
    if missing:
        raise KeyError(f"Game ids not found in the recorded data: {missing}")

    self.plugins.start_arena(self.game)

    results = []
    for game_id in game_ids:
        game_data = recorded_games[game_id]
        actions = game_data["actions"]

        # The recorded list is split by position: even indices go to
        # player 1, odd indices to player 2.
        agent_1 = ReplayAgent(game_data["player1"], actions[::2])
        agent_2 = ReplayAgent(game_data["player2"], actions[1::2])

        results.append(self._play_game(agent_1, agent_2, game_id))

    self.plugins.end_arena(self.game, results)

deep_quoridor/src/arena_yaml_recorder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from arena import ArenaPlugin, Agent, GameResult
2-
from typing import Optional
32
import yaml
43

54

@@ -35,3 +34,11 @@ def end_arena(self, game, results: list[GameResult]):
3534
}
3635
with open(self.filename, "w") as file:
3736
file.write(yaml.dump(output, sort_keys=False))
37+
38+
@staticmethod
def load_recorded_arena_data(filename: str) -> dict:
    """Load recorded arena data from a YAML file.

    Args:
        filename: Path of the YAML file to read.

    Returns:
        The top-level mapping parsed from the file.

    Raises:
        ValueError: If the file does not contain a top-level mapping (for
            example, when it is empty and parses to ``None``).
    """
    with open(filename, "r") as file:
        # NOTE(review): yaml.load with FullLoader resolves more tags than
        # yaml.safe_load does. If recordings may come from untrusted sources,
        # prefer yaml.safe_load — confirm first that recorded files only hold
        # plain YAML types.
        data = yaml.load(file, Loader=yaml.FullLoader)

    if not isinstance(data, dict):
        raise ValueError(
            f"{filename} does not contain recorded arena data "
            f"(expected a mapping, got {type(data).__name__})"
        )
    return data
42+
43+
44+

deep_quoridor/src/main.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,45 @@
22
from arena_yaml_recorder import ArenaYAMLRecorder
33
from arena import Arena
44
from renderers import Renderer
5-
from agents import Agent
5+
from agents import AgentRegistry
66

77
if __name__ == "__main__":
88
parser = argparse.ArgumentParser(description="Deep Quoridor")
99
parser.add_argument("-N", "--board_size", type=int, default=None, help="Board Size")
10-
parser.add_argument("-W", "--max_walls", type=int, default=None, help="Max walls per player")
11-
parser.add_argument("-r", "--renderer", choices=Renderer.names(), default="results", help="Render mode")
12-
parser.add_argument("--step_rewards", action="store_true", default=False, help="Enable step rewards")
10+
parser.add_argument(
11+
"-W", "--max_walls", type=int, default=None, help="Max walls per player"
12+
)
13+
parser.add_argument(
14+
"-r",
15+
"--renderer",
16+
choices=Renderer.names(),
17+
default="results",
18+
help="Render mode",
19+
)
20+
parser.add_argument(
21+
"--step_rewards", action="store_true", default=False, help="Enable step rewards"
22+
)
1323
parser.add_argument(
1424
"-p",
1525
"--players",
1626
nargs="+",
17-
choices=Agent.names(),
27+
choices=AgentRegistry.names(),
1828
default=["random", "simple"],
1929
help="List of players to compete against each other",
2030
)
2131
parser.add_argument(
22-
"-A", "--all", action="store_true", default=False, help="Plays a tournament of all agents against each other"
32+
"-A",
33+
"--all",
34+
action="store_true",
35+
default=False,
36+
help="Plays a tournament of all agents against each other",
2337
)
2438
parser.add_argument(
25-
"-t", "--times", type=int, default=10, help="Number of times each player will play with each opponent"
39+
"-t",
40+
"--times",
41+
type=int,
42+
default=10,
43+
help="Number of times each player will play with each opponent",
2644
)
2745
parser.add_argument(
2846
"--games_output_filename",
@@ -39,7 +57,7 @@
3957
if args.games_output_filename != "None":
4058
saver = ArenaYAMLRecorder(args.games_output_filename)
4159

42-
players = Agent.names() if args.all else args.players
60+
players = AgentRegistry.names() if args.all else args.players
4361

4462
arena_args = {
4563
"board_size": args.board_size,

deep_quoridor/src/replay_tool.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import argparse
2+
import time
3+
from arena_yaml_recorder import ArenaYAMLRecorder
4+
from arena import Arena
5+
from arena import ArenaPlugin
6+
from renderers import Renderer
7+
8+
9+
"""Deep Quoridor Game Replay Tool
10+
11+
This script allows replaying recorded Quoridor games from YAML files. It provides command-line
12+
options to customize the replay experience, including renderer selection and specific game filtering.
13+
14+
Command-line Arguments:
15+
-r, --renderer: Render mode for game replay visualization (default: "results")
16+
-t, --time_delay: Time delay between moves in milliseconds. Only applied if > 0 (default: -1)
17+
-g, --game_ids: List of specific game IDs to replay. If not set, replays all games
18+
-f, --games_input_filename: Path to YAML file containing recorded games (default: "game_recording.yaml")
19+
20+
Example Usage:
21+
python replay_tool.py -r text -g game_0008 game_0009 -f my_games.yaml
22+
python replay_tool.py -r text -t 500 -f my_games.yaml # Replay with 500ms delay between moves
23+
"""
24+
25+
26+
class ActionDelayPlugin(ArenaPlugin):
    """Arena plugin that sleeps each time its ``action`` hook is invoked.

    Args:
        time_delay: Length of the pause, in milliseconds.
    """

    def __init__(self, time_delay: int):
        self.time_delay = time_delay

    def action(self, game, step, agent, action):
        """Sleep for the configured delay; the hook arguments are not used."""
        delay_in_seconds = self.time_delay / 1000
        time.sleep(delay_in_seconds)
32+
33+
34+
if __name__ == "__main__":
35+
parser = argparse.ArgumentParser(description="Deep Quoridor replay tool")
36+
parser.add_argument(
37+
"-r",
38+
"--renderer",
39+
choices=Renderer.names(),
40+
default="results",
41+
help="Render mode",
42+
)
43+
parser.add_argument(
44+
"-t",
45+
"--time_delay",
46+
type=int,
47+
default=-1,
48+
help="Time delay between moves in ms, > 0 or ignored (default: -1)",
49+
)
50+
parser.add_argument(
51+
"-g",
52+
"--game_ids",
53+
nargs="+",
54+
type=str,
55+
default=[],
56+
help="Game IDs to replay, if not set it will replay all games",
57+
)
58+
parser.add_argument(
59+
"-f",
60+
"--games_input_filename",
61+
type=str,
62+
default="game_recording.yaml",
63+
help="Load the played games from the file",
64+
)
65+
66+
args = parser.parse_args()
67+
68+
renderer = Renderer.create(args.renderer)
69+
70+
arena_data = ArenaYAMLRecorder.load_recorded_arena_data(args.games_input_filename)
71+
72+
arena_args = {
73+
"board_size": arena_data["config"]["board_size"],
74+
"max_walls": arena_data["config"]["max_walls"],
75+
"step_rewards": arena_data["config"]["step_rewards"],
76+
"renderer": renderer,
77+
"plugins": [ActionDelayPlugin(args.time_delay)] if args.time_delay > 0 else [],
78+
}
79+
80+
arena_args = {k: v for k, v in arena_args.items() if v is not None}
81+
arena = Arena(**arena_args)
82+
83+
arena.replay_games(arena_data, args.game_ids)

0 commit comments

Comments
 (0)