Skip to content

Commit c5ea988

Browse files
author
Julian Cerruti
authored
Merge pull request #110 from adamantivm/dq/deep-q-training
Windsurf-generated DeepQNetwork and training code - Claude Sonnet 3.7
2 parents 481bbdb + 71afd53 commit c5ea988

File tree

10 files changed

+384
-3219
lines changed

10 files changed

+384
-3219
lines changed

.github/workflows/python-app.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@ jobs:
3232
PYTHONPATH=$PYTHONPATH:$(pwd)/deep_quoridor/src pytest deep_quoridor/test
3333
- name: Run some games as a sanity check
3434
run: |
35-
PYTHONPATH=$PYTHONPATH:$(pwd)/deep_quoridor/src python deep_quoridor/src/main.py -t 2
35+
PYTHONPATH=$PYTHONPATH:$(pwd)/deep_quoridor/src python deep_quoridor/src/play.py -t 2
3636

deep_quoridor/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ gymnasium
33
numpy
44
pytest
55
pyyaml
6-
prettytable
6+
prettytable
7+
torch

deep_quoridor/src/agents/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ class AgentRegistry:
1515
agents = {}
1616

1717
@staticmethod
18-
def create(friendly_name: str) -> Agent:
19-
return AgentRegistry.agents[friendly_name]()
18+
def create(friendly_name: str, **kwargs) -> Agent:
19+
return AgentRegistry.agents[friendly_name](**kwargs)
2020

2121
@staticmethod
2222
def names():
@@ -44,5 +44,8 @@ def _friendly_name(class_name: str):
4444
return class_name.replace("Agent", "").lower()
4545

4646

47-
from agents.random import RandomAgent # noqa: E402, F401
48-
from agents.simple import SimpleAgent # noqa: E402, F401
47+
__all__ = ["RandomAgent", "SimpleAgent", "Agent", "FlatDQNAgent", "Pretrained01FlatDQNAgent"]
48+
49+
from agents.random import RandomAgent # noqa: E402
50+
from agents.simple import SimpleAgent # noqa: E402
51+
from agents.flat_dqn import FlatDQNAgent, Pretrained01FlatDQNAgent # noqa: E402
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
import numpy as np
2+
import os
3+
import torch
4+
import torch.nn as nn
5+
import torch.optim as optim
6+
from collections import deque
7+
import random
8+
from agents import SelfRegisteringAgent
9+
10+
11+
class DQNNetwork(nn.Module):
    """
    Fully-connected Q-network for the Quoridor game.

    Maps a flattened observation (board occupancy, wall placements, and the
    two players' remaining wall counts) to one Q-value per discrete action.
    """

    def __init__(self, board_size, action_size):
        super().__init__()

        # Flattened input layout (must match FlatDQNAgent.preprocess_observation):
        #   - board positions: board_size * board_size entries
        #     (assumes the flattened board has exactly board_size**2 values — TODO confirm against env)
        #   - walls: (board_size - 1)^2 grid cells x 2 orientations (vertical/horizontal)
        #   - 2 scalars: walls remaining for self and for the opponent
        cells = board_size * board_size
        wall_slots = (board_size - 1) * (board_size - 1) * 2
        input_dim = cells + wall_slots + 2

        # NOTE: keep the Sequential layer order/indices stable — saved
        # checkpoints are keyed by "model.0", "model.2", "model.4".
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_size),
        )

    def forward(self, x):
        """Return Q-values for every action given observation tensor *x*."""
        return self.model(x)
36+
37+
38+
class ReplayBuffer:
    """
    Fixed-capacity FIFO store of (state, action, reward, next_state, done)
    transitions, with uniform random sampling for training batches.
    """

    def __init__(self, capacity):
        # deque evicts the oldest transition automatically once full
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one transition, dropping the oldest one if at capacity."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return *batch_size* transitions drawn uniformly without replacement."""
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)
54+
55+
56+
class FlatDQNAgent(SelfRegisteringAgent):
    """
    Deep Q-Network agent that acts on a flattened Quoridor observation.

    Maintains an online and a target network, an experience-replay buffer,
    and an epsilon-greedy exploration schedule.

    Args:
        board_size: Side length of the (square) Quoridor board.
        epsilon: Initial exploration rate.
        epsilon_min: Lower bound the exploration rate decays towards.
        epsilon_decay: Multiplicative decay applied to epsilon per train step.
        gamma: Discount factor for future rewards.
    """

    def __init__(self, board_size, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995, gamma=0.99):
        super().__init__()
        self.board_size = board_size
        # Assumes action representation is a flat array of size board_size**2 + (board_size - 1)**2 * 2
        # See quoridor_env.py for details
        self.action_size = board_size**2 + (board_size - 1) ** 2 * 2
        self.epsilon = epsilon  # Exploration rate
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.gamma = gamma  # Discount factor

        # Initialize Q-networks (online and target)
        self.online_network = DQNNetwork(board_size, self.action_size)
        self.target_network = DQNNetwork(board_size, self.action_size)
        self.update_target_network()

        # Set up optimizer
        self.optimizer = optim.Adam(self.online_network.parameters(), lr=0.001)
        self.criterion = nn.MSELoss()

        # Initialize replay buffer
        self.replay_buffer = ReplayBuffer(capacity=10000)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.online_network.to(self.device)
        self.target_network.to(self.device)

    def update_target_network(self):
        """Copy parameters from online network to target network."""
        self.target_network.load_state_dict(self.online_network.state_dict())

    def preprocess_observation(self, observation):
        """
        Convert the observation dict to a flat tensor on the agent's device.

        Concatenates the flattened board, flattened walls, and the two
        remaining-wall counters (see quoridor_env.py for the dict layout).
        """
        obs = observation["observation"]
        board = obs["board"].flatten()
        walls = obs["walls"].flatten()
        my_walls = np.array([obs["my_walls_remaining"]])
        opponent_walls = np.array([obs["opponent_walls_remaining"]])

        # Concatenate all components
        flat_obs = np.concatenate([board, walls, my_walls, opponent_walls])
        return torch.FloatTensor(flat_obs).to(self.device)

    def get_action(self, game):
        """
        Select an action using an epsilon-greedy policy.

        Returns None when the game is already over; otherwise a Python int
        index into the flat action space (invalid actions masked out).
        """
        observation, _, termination, truncation, _ = game.last()
        if termination or truncation:
            return None

        mask = observation["action_mask"]
        valid_actions = np.where(mask == 1)[0]

        # With probability epsilon, select a random action (exploration).
        # int(...) keeps the return type consistent with the greedy branch.
        if random.random() < self.epsilon:
            return int(np.random.choice(valid_actions))

        # Otherwise, select the action with the highest Q-value (exploitation)
        state = self.preprocess_observation(observation)
        with torch.no_grad():
            q_values = self.online_network(state)

        # Apply action mask: drive invalid actions to a large negative value
        # so argmax can never pick them.
        mask_tensor = torch.FloatTensor(mask).to(self.device)
        q_values = q_values * mask_tensor - 1e9 * (1 - mask_tensor)

        return torch.argmax(q_values).item()

    def train(self, batch_size):
        """
        Train the online network on one batch sampled from the replay buffer.

        Returns the scalar loss, or None when the buffer has fewer than
        *batch_size* transitions.
        """
        if len(self.replay_buffer) < batch_size:
            return

        # Sample a batch of transitions
        batch = self.replay_buffer.sample(batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors
        states = torch.stack([torch.FloatTensor(s) for s in states]).to(self.device)
        actions = torch.LongTensor(actions).unsqueeze(1).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.stack([torch.FloatTensor(s) for s in next_states]).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Compute current Q values. squeeze(1) (not squeeze()) so a
        # batch_size of 1 still yields a 1-d tensor matching the targets.
        current_q_values = self.online_network(states).gather(1, actions).squeeze(1)

        # Compute next Q values using target network; (1 - dones) zeroes the
        # bootstrap term for terminal transitions.
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute loss and update online network
        loss = self.criterion(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        return loss.item()

    def save_model(self, path):
        """Save the online network's weights to disk."""
        torch.save(self.online_network.state_dict(), path)

    def load_model(self, path):
        """Load weights from disk into both networks.

        map_location ensures a checkpoint saved on GPU loads on a CPU-only
        machine (and vice versa).
        """
        self.online_network.load_state_dict(torch.load(path, map_location=self.device))
        self.update_target_network()
177+
178+
179+
class Pretrained01FlatDQNAgent(FlatDQNAgent):
    """
    A FlatDQNAgent initialized from a pre-trained checkpoint.

    The checkpoint location defaults to the original hard-coded path but can
    be overridden with the ``model_path`` keyword argument or the
    ``FLAT_DQN_MODEL_PATH`` environment variable. If the file is missing the
    agent starts untrained and a warning is printed (best-effort, no raise).
    """

    # Default checkpoint location (developer-specific absolute path; override
    # via the model_path kwarg or FLAT_DQN_MODEL_PATH).
    DEFAULT_MODEL_PATH = "/home/julian/aaae/deep-rabbit-hole/code/deep_rabbit_hole/models/dqn_flat_nostep_final.pt"

    def __init__(self, board_size, **kwargs):
        # Resolution order: explicit kwarg > environment variable > default.
        model_path = kwargs.pop("model_path", None) or os.environ.get("FLAT_DQN_MODEL_PATH") or self.DEFAULT_MODEL_PATH
        super().__init__(board_size)
        if os.path.exists(model_path):
            print(f"Loading pre-trained model from {model_path}")
            self.load_model(model_path)
        else:
            print(
                f"Warning: Model file {model_path} not found, using untrained agent. Ask Julian for the weights file."
            )

deep_quoridor/src/agents/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
class RandomAgent(SelfRegisteringAgent):
5-
def __init__(self):
5+
def __init__(self, **kwargs):
66
super().__init__()
77

88
def get_action(self, game):

deep_quoridor/src/agents/simple.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def sample_random_action_sequence(game, max_path_length):
3030

3131

3232
class SimpleAgent(SelfRegisteringAgent):
33-
def __init__(self, sequence_length=3, num_sequences=10):
33+
def __init__(self, sequence_length=3, num_sequences=10, **kwargs):
3434
super().__init__()
3535
self.sequence_length = sequence_length
3636
self.num_sequences = num_sequences
@@ -42,9 +42,7 @@ def get_action(self, game):
4242

4343
possible_action_sequences = []
4444
for _ in range(self.num_sequences):
45-
action_sequence, total_reward = sample_random_action_sequence(
46-
game.copy(), self.sequence_length
47-
)
45+
action_sequence, total_reward = sample_random_action_sequence(game.copy(), self.sequence_length)
4846
possible_action_sequences.append((action_sequence, total_reward))
4947

5048
# Choose the action sequence with the highest reward.

deep_quoridor/src/arena.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,9 @@ def __init__(
8080
self.board_size = board_size
8181
self.max_walls = max_walls
8282
self.step_rewards = step_rewards
83-
self.game = env(
84-
board_size=board_size, max_walls=max_walls, step_rewards=step_rewards
85-
)
83+
self.game = env(board_size=board_size, max_walls=max_walls, step_rewards=step_rewards)
8684

87-
self.plugins = CompositeArenaPlugin(
88-
[p for p in plugins + [renderer, saver] if p is not None]
89-
)
85+
self.plugins = CompositeArenaPlugin([p for p in plugins + [renderer, saver] if p is not None])
9086

9187
def _play_game(self, agent1: Agent, agent2: Agent, game_id: str) -> GameResult:
9288
self.game.reset()
@@ -134,12 +130,9 @@ def play_games(self, players: list[str], times: int):
134130
for i in range(len(players)):
135131
for j in range(i + 1, len(players)):
136132
for t in range(times):
137-
agent_i = AgentRegistry.create(players[i])
138-
agent_j = AgentRegistry.create(players[j])
139-
agent_1, agent_2 = (
140-
(agent_i, agent_j) if t % 2 == 0 else (agent_j, agent_i)
141-
)
142-
133+
agent_i = AgentRegistry.create(players[i], board_size=self.board_size)
134+
agent_j = AgentRegistry.create(players[j], board_size=self.board_size)
135+
agent_1, agent_2 = (agent_i, agent_j) if t % 2 == 0 else (agent_j, agent_i)
143136
result = self._play_game(agent_1, agent_2, f"game_{match_id:04d}")
144137
results.append(result)
145138
match_id += 1
Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,15 @@
77
if __name__ == "__main__":
88
parser = argparse.ArgumentParser(description="Deep Quoridor")
99
parser.add_argument("-N", "--board_size", type=int, default=None, help="Board Size")
10-
parser.add_argument(
11-
"-W", "--max_walls", type=int, default=None, help="Max walls per player"
12-
)
10+
parser.add_argument("-W", "--max_walls", type=int, default=None, help="Max walls per player")
1311
parser.add_argument(
1412
"-r",
1513
"--renderer",
1614
choices=Renderer.names(),
1715
default="results",
1816
help="Render mode",
1917
)
20-
parser.add_argument(
21-
"--step_rewards", action="store_true", default=False, help="Enable step rewards"
22-
)
18+
parser.add_argument("--step_rewards", action="store_true", default=False, help="Enable step rewards")
2319
parser.add_argument(
2420
"-p",
2521
"--players",

0 commit comments

Comments
 (0)